From 8b894933a73f4c477ba1401299c29f3553b622ee Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Wed, 17 Sep 2025 14:40:53 -0700 Subject: [PATCH] engine: add remote proxy (#12307) --- api/client.go | 25 +++- api/types.go | 125 +++++++++++++--- auth/auth.go | 13 ++ cmd/cmd.go | 155 ++++++++++++++++++-- cmd/cmd_test.go | 8 +- envconfig/config.go | 12 ++ server/create.go | 152 ++++++++++++++++--- server/create_test.go | 151 +++++++++++++++++++ server/images.go | 46 ++++-- server/routes.go | 277 +++++++++++++++++++++++++++++++---- server/routes_create_test.go | 74 ++++++++++ server/routes_test.go | 10 +- 12 files changed, 948 insertions(+), 100 deletions(-) diff --git a/api/client.go b/api/client.go index 7cc2acb3d1..20e6d79571 100644 --- a/api/client.go +++ b/api/client.go @@ -222,7 +222,17 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f return fmt.Errorf("unmarshal: %w", err) } - if response.StatusCode >= http.StatusBadRequest { + if response.StatusCode == http.StatusUnauthorized { + pubKey, pkErr := auth.GetPublicKey() + if pkErr != nil { + return pkErr + } + return AuthorizationError{ + StatusCode: response.StatusCode, + Status: response.Status, + PublicKey: pubKey, + } + } else if response.StatusCode >= http.StatusBadRequest { return StatusError{ StatusCode: response.StatusCode, Status: response.Status, @@ -428,3 +438,16 @@ func (c *Client) Version(ctx context.Context) (string, error) { return version.Version, nil } + +// Signout will disconnect an ollama instance from ollama.com +func (c *Client) Signout(ctx context.Context, encodedKey string) error { + return c.do(ctx, http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey), nil, nil) +} + +func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) { + var resp UserResponse + if err := c.do(ctx, http.MethodPost, "/api/me", nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} diff --git a/api/types.go b/api/types.go index df3504c3b2..5b8e034c22 100644 --- a/api/types.go +++ b/api/types.go @@ -11,6 +11,8 @@ import ( "strings" "time" + "github.com/google/uuid" + "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/types/model" ) @@ -36,6 +38,19 @@ func (e StatusError) Error() string { } } +type AuthorizationError struct { + StatusCode int + Status string + PublicKey string `json:"public_key"` +} + +func (e AuthorizationError) Error() string { + if e.Status != "" { + return e.Status + } + return "something went wrong, please see the ollama server logs for details" +} + // ImageData represents the raw binary data of an image file. type ImageData []byte @@ -313,14 +328,29 @@ func (t *ToolFunction) String() string { // ChatResponse is the response returned by [Client.Chat]. Its fields are // similar to [GenerateResponse]. type ChatResponse struct { - Model string `json:"model"` - CreatedAt time.Time `json:"created_at"` - Message Message `json:"message"` - DoneReason string `json:"done_reason,omitempty"` - DebugInfo *DebugInfo `json:"_debug_info,omitempty"` + // Model is the model name that generated the response. + Model string `json:"model"` + // RemoteModel is the name of the upstream model that generated the response. + RemoteModel string `json:"remote_model,omitempty"` + + // RemoteHost is the URL of the upstream Ollama host that generated the response. + RemoteHost string `json:"remote_host,omitempty"` + + // CreatedAt is the timestamp of the response. 
+ CreatedAt time.Time `json:"created_at"` + + // Message contains the message or part of a message from the model. + Message Message `json:"message"` + + // Done specifies if the response is complete. Done bool `json:"done"` + // DoneReason is the reason the model stopped generating text. + DoneReason string `json:"done_reason,omitempty"` + + DebugInfo *DebugInfo `json:"_debug_info,omitempty"` + Metrics } @@ -425,20 +455,47 @@ type EmbeddingResponse struct { // CreateRequest is the request passed to [Client.Create]. type CreateRequest struct { - Model string `json:"model"` - Stream *bool `json:"stream,omitempty"` + // Model is the model name to create. + Model string `json:"model"` + + // Stream specifies whether the response is streaming; it is true by default. + Stream *bool `json:"stream,omitempty"` + + // Quantize is the quantization format for the model; leave blank to not change the quantization level. Quantize string `json:"quantize,omitempty"` - From string `json:"from,omitempty"` - Files map[string]string `json:"files,omitempty"` - Adapters map[string]string `json:"adapters,omitempty"` - Template string `json:"template,omitempty"` - License any `json:"license,omitempty"` - System string `json:"system,omitempty"` - Parameters map[string]any `json:"parameters,omitempty"` - Messages []Message `json:"messages,omitempty"` - Renderer string `json:"renderer,omitempty"` - Parser string `json:"parser,omitempty"` + // From is the name of the model or file to use as the source. + From string `json:"from,omitempty"` + + // RemoteHost is the URL of the upstream ollama API for the model (if any). + RemoteHost string `json:"remote_host,omitempty"` + + // Files is a map of files include when creating the model. + Files map[string]string `json:"files,omitempty"` + + // Adapters is a map of LoRA adapters to include when creating the model. + Adapters map[string]string `json:"adapters,omitempty"` + + // Template is the template used when constructing a request to the model. + Template string `json:"template,omitempty"` + + // License is a string or list of strings for licenses. + License any `json:"license,omitempty"` + + // System is the system prompt for the model. + System string `json:"system,omitempty"` + + // Parameters is a map of hyper-parameters which are applied to the model. + Parameters map[string]any `json:"parameters,omitempty"` + + // Messages is a list of messages added to the model before chat and generation requests. + Messages []Message `json:"messages,omitempty"` + + Renderer string `json:"renderer,omitempty"` + Parser string `json:"parser,omitempty"` + + // Info is a map of additional information for the model + Info map[string]any `json:"info,omitempty"` // Deprecated: set the model name with Model instead Name string `json:"name"` @@ -480,6 +537,8 @@ type ShowResponse struct { Parser string `json:"parser,omitempty"` Details ModelDetails `json:"details,omitempty"` Messages []Message `json:"messages,omitempty"` + RemoteModel string `json:"remote_model,omitempty"` + RemoteHost string `json:"remote_host,omitempty"` ModelInfo map[string]any `json:"model_info,omitempty"` ProjectorInfo map[string]any `json:"projector_info,omitempty"` Tensors []Tensor `json:"tensors,omitempty"` @@ -538,12 +597,14 @@ type ProcessResponse struct { // ListModelResponse is a single model description in [ListResponse]. 
type ListModelResponse struct { - Name string `json:"name"` - Model string `json:"model"` - ModifiedAt time.Time `json:"modified_at"` - Size int64 `json:"size"` - Digest string `json:"digest"` - Details ModelDetails `json:"details,omitempty"` + Name string `json:"name"` + Model string `json:"model"` + RemoteModel string `json:"remote_model,omitempty"` + RemoteHost string `json:"remote_host,omitempty"` + ModifiedAt time.Time `json:"modified_at"` + Size int64 `json:"size"` + Digest string `json:"digest"` + Details ModelDetails `json:"details,omitempty"` } // ProcessModelResponse is a single model description in [ProcessResponse]. @@ -567,6 +628,12 @@ type GenerateResponse struct { // Model is the model name that generated the response. Model string `json:"model"` + // RemoteModel is the name of the upstream model that generated the response. + RemoteModel string `json:"remote_model,omitempty"` + + // RemoteHost is the URL of the upstream Ollama host that generated the response. + RemoteHost string `json:"remote_host,omitempty"` + // CreatedAt is the timestamp of the response. CreatedAt time.Time `json:"created_at"` @@ -604,6 +671,18 @@ type ModelDetails struct { QuantizationLevel string `json:"quantization_level"` } +// UserResponse provides information about a user. +type UserResponse struct { + ID uuid.UUID `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + Bio string `json:"bio,omitempty"` + AvatarURL string `json:"avatarurl,omitempty"` + FirstName string `json:"firstname,omitempty"` + LastName string `json:"lastname,omitempty"` + Plan string `json:"plan,omitempty"` +} + // Tensor describes the metadata for a given tensor. type Tensor struct { Name string `json:"name"` diff --git a/auth/auth.go b/auth/auth.go index e1d8541247..61a8626c34 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -19,6 +19,19 @@ import ( const defaultPrivateKey = "id_ed25519" func keyPath() (string, error) { + fileExists := func(fp string) bool { + info, err := os.Stat(fp) + if err != nil { + return false + } + return !info.IsDir() + } + + systemPath := filepath.Join("/usr/share/ollama/.ollama", defaultPrivateKey) + if fileExists(systemPath) { + return systemPath, nil + } + home, err := os.UserHomeDir() if err != nil { return "", err diff --git a/cmd/cmd.go b/cmd/cmd.go index 19f1e192f7..294e1662fa 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -5,6 +5,7 @@ import ( "context" "crypto/ed25519" "crypto/rand" + "encoding/base64" "encoding/json" "encoding/pem" "errors" @@ -14,6 +15,7 @@ import ( "math" "net" "net/http" + "net/url" "os" "os/signal" "path/filepath" @@ -35,6 +37,7 @@ import ( "golang.org/x/term" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/auth" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/parser" @@ -47,6 +50,8 @@ import ( "github.com/ollama/ollama/version" ) +const ConnectInstructions = "To sign in, navigate to:\n https://ollama.com/connect?name=%s&key=%s\n\n" + // ensureThinkingSupport emits a warning if the model does not advertise thinking support func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) { if name == "" { @@ -286,7 +291,17 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error { Think: opts.Think, } - return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil }) + return client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error { + if r.RemoteModel != "" && opts.ShowConnect { + p.StopAndClear() + if 
strings.HasPrefix(r.RemoteHost, "https://ollama.com") { + fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", r.RemoteModel) + } else { + fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", r.RemoteModel, r.RemoteHost) + } + } + return nil + }) } func StopHandler(cmd *cobra.Command, args []string) error { @@ -307,9 +322,10 @@ func RunHandler(cmd *cobra.Command, args []string) error { interactive := true opts := runOptions{ - Model: args[0], - WordWrap: os.Getenv("TERM") == "xterm-256color", - Options: map[string]any{}, + Model: args[0], + WordWrap: os.Getenv("TERM") == "xterm-256color", + Options: map[string]any{}, + ShowConnect: true, } format, err := cmd.Flags().GetString("format") @@ -367,6 +383,7 @@ func RunHandler(cmd *cobra.Command, args []string) error { } prompts = append([]string{string(in)}, prompts...) + opts.ShowConnect = false opts.WordWrap = false interactive = false } @@ -433,6 +450,21 @@ func RunHandler(cmd *cobra.Command, args []string) error { if interactive { if err := loadOrUnloadModel(cmd, &opts); err != nil { + var sErr api.AuthorizationError + if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized { + pubKey, pkErr := auth.GetPublicKey() + if pkErr != nil { + return pkErr + } + // the server and the client both have the same public key + if pubKey == sErr.PublicKey { + h, _ := os.Hostname() + encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey)) + fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n") + fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey) + } + return nil + } return err } @@ -453,6 +485,56 @@ func RunHandler(cmd *cobra.Command, args []string) error { return generate(cmd, opts) } +func SigninHandler(cmd *cobra.Command, args []string) error { + client, err := api.ClientFromEnvironment() + if err != nil { + return err + } + + user, err := client.Whoami(cmd.Context()) + if err != nil { + return err + } + + if user != nil && user.Name != "" { + fmt.Printf("You are already signed in as user '%s'\n", user.Name) + fmt.Println() + return nil + } + + pubKey, pkErr := auth.GetPublicKey() + if pkErr != nil { + return pkErr + } + encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey)) + + h, _ := os.Hostname() + fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey) + + return nil +} + +func SignoutHandler(cmd *cobra.Command, args []string) error { + pubKey, pkErr := auth.GetPublicKey() + if pkErr != nil { + return pkErr + } + encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey)) + + client, err := api.ClientFromEnvironment() + if err != nil { + return err + } + + err = client.Signout(cmd.Context(), encKey) + if err != nil { + return err + } + fmt.Println("You have signed out of ollama.com") + fmt.Println() + return nil +} + func PushHandler(cmd *cobra.Command, args []string) error { client, err := api.ClientFromEnvironment() if err != nil { @@ -505,7 +587,8 @@ func PushHandler(cmd *cobra.Command, args []string) error { if spinner != nil { spinner.Stop() } - if strings.Contains(err.Error(), "access denied") { + errStr := strings.ToLower(err.Error()) + if strings.Contains(errStr, "access denied") || strings.Contains(errStr, "unauthorized") { return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own") } return err @@ -539,7 +622,14 @@ func ListHandler(cmd *cobra.Command, args []string) error { for _, m := range models.Models { if len(args) == 0 || strings.HasPrefix(strings.ToLower(m.Name), strings.ToLower(args[0])) { - data = 
append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")}) + var size string + if m.RemoteModel != "" { + size = "-" + } else { + size = format.HumanBytes(m.Size) + } + + data = append(data, []string{m.Name, m.Digest[:12], size, format.HumanTime(m.ModifiedAt, "Never")}) } } @@ -624,8 +714,8 @@ func DeleteHandler(cmd *cobra.Command, args []string) error { KeepAlive: &api.Duration{Duration: 0}, } if err := loadOrUnloadModel(cmd, opts); err != nil { - if !strings.Contains(err.Error(), "not found") { - return fmt.Errorf("unable to stop existing running model \"%s\": %s", args[0], err) + if !strings.Contains(strings.ToLower(err.Error()), "not found") { + fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", args[0]) } } @@ -736,12 +826,36 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error { } tableRender("Model", func() (rows [][]string) { + if resp.RemoteHost != "" { + rows = append(rows, []string{"", "Remote model", resp.RemoteModel}) + rows = append(rows, []string{"", "Remote URL", resp.RemoteHost}) + } + if resp.ModelInfo != nil { arch := resp.ModelInfo["general.architecture"].(string) rows = append(rows, []string{"", "architecture", arch}) - rows = append(rows, []string{"", "parameters", format.HumanNumber(uint64(resp.ModelInfo["general.parameter_count"].(float64)))}) - rows = append(rows, []string{"", "context length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64), 'f', -1, 64)}) - rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64), 'f', -1, 64)}) + + var paramStr string + if resp.Details.ParameterSize != "" { + paramStr = resp.Details.ParameterSize + } else if v, ok := resp.ModelInfo["general.parameter_count"]; ok { + if f, ok := v.(float64); ok { + paramStr = format.HumanNumber(uint64(f)) + } + } + rows = append(rows, []string{"", "parameters", paramStr}) + + if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok { + if f, ok := v.(float64); ok { + rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)}) + } + } + + if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok { + if f, ok := v.(float64); ok { + rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)}) + } + } } else { rows = append(rows, []string{"", "architecture", resp.Details.Family}) rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize}) @@ -989,6 +1103,7 @@ type runOptions struct { KeepAlive *api.Duration Think *api.ThinkValue HideThinking bool + ShowConnect bool } type displayResponseState struct { @@ -1544,6 +1659,22 @@ func NewCLI() *cobra.Command { pushCmd.Flags().Bool("insecure", false, "Use an insecure registry") + signinCmd := &cobra.Command{ + Use: "signin", + Short: "Sign in to ollama.com", + Args: cobra.ExactArgs(0), + PreRunE: checkServerHeartbeat, + RunE: SigninHandler, + } + + signoutCmd := &cobra.Command{ + Use: "signout", + Short: "Sign out from ollama.com", + Args: cobra.ExactArgs(0), + PreRunE: checkServerHeartbeat, + RunE: SignoutHandler, + } + listCmd := &cobra.Command{ Use: "list", Aliases: []string{"ls"}, @@ -1638,6 +1769,8 @@ func NewCLI() *cobra.Command { stopCmd, pullCmd, pushCmd, + signinCmd, + signoutCmd, listCmd, psCmd, copyCmd, diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index cf5fe7caa4..bb793572fc 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -3,6 +3,7 
@@ package cmd import ( "bytes" "encoding/json" + "fmt" "io" "net/http" "net/http/httptest" @@ -304,6 +305,8 @@ func TestDeleteHandler(t *testing.T) { w.WriteHeader(http.StatusOK) } else { w.WriteHeader(http.StatusNotFound) + errPayload := `{"error":"model '%s' not found"}` + w.Write([]byte(fmt.Sprintf(errPayload, req.Name))) } return } @@ -346,7 +349,7 @@ func TestDeleteHandler(t *testing.T) { } err := DeleteHandler(cmd, []string{"test-model-not-found"}) - if err == nil || !strings.Contains(err.Error(), "unable to stop existing running model \"test-model-not-found\"") { + if err == nil || !strings.Contains(err.Error(), "model 'test-model-not-found' not found") { t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err) } } @@ -499,7 +502,7 @@ func TestPushHandler(t *testing.T) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) err := json.NewEncoder(w).Encode(map[string]string{ - "error": "access denied", + "error": "403: {\"errors\":[{\"code\":\"ACCESS DENIED\", \"message\":\"access denied\"}]}", }) if err != nil { t.Fatal(err) @@ -522,6 +525,7 @@ func TestPushHandler(t *testing.T) { defer mockServer.Close() t.Setenv("OLLAMA_HOST", mockServer.URL) + initializeKeypair() cmd := &cobra.Command{} cmd.Flags().Bool("insecure", false, "") diff --git a/envconfig/config.go b/envconfig/config.go index 7fc0188703..09243ab95a 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -134,6 +134,17 @@ func LoadTimeout() (loadTimeout time.Duration) { return loadTimeout } +func Remotes() []string { + var r []string + raw := strings.TrimSpace(Var("OLLAMA_REMOTES")) + if raw == "" { + r = []string{"ollama.com"} + } else { + r = strings.Split(raw, ",") + } + return r +} + func Bool(k string) func() bool { return func() bool { if s := Var(k); s != "" { @@ -270,6 +281,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"}, "OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"}, "OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"}, + "OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"}, // Informational "HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"}, diff --git a/server/create.go b/server/create.go index f08f18b340..19f24ec805 100644 --- a/server/create.go +++ b/server/create.go @@ -10,8 +10,11 @@ import ( "io" "io/fs" "log/slog" + "net" "net/http" + "net/url" "os" + "path" "path/filepath" "slices" "strings" @@ -39,6 +42,14 @@ var ( ) func (s *Server) CreateHandler(c *gin.Context) { + config := &ConfigV2{ + OS: "linux", + Architecture: "amd64", + RootFS: RootFS{ + Type: "layers", + }, + } + var r api.CreateRequest if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) @@ -48,6 +59,9 @@ func (s *Server) CreateHandler(c *gin.Context) { return } + config.Renderer = r.Renderer + config.Parser = r.Parser + for v := range r.Files { if !fs.ValidPath(v) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()}) @@ -77,20 +91,34 @@ func (s *Server) CreateHandler(c *gin.Context) { oldManifest, _ := ParseNamedManifest(name) var baseLayers []*layerGGML + var err error + var remote bool + if r.From != "" { - slog.Debug("create model from 
model name") + slog.Debug("create model from model name", "from", r.From) fromName := model.ParseName(r.From) if !fromName.IsValid() { ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest} return } + if r.RemoteHost != "" { + ru, err := remoteURL(r.RemoteHost) + if err != nil { + ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest} + return + } - ctx, cancel := context.WithCancel(c.Request.Context()) - defer cancel() + config.RemoteModel = r.From + config.RemoteHost = ru + remote = true + } else { + ctx, cancel := context.WithCancel(c.Request.Context()) + defer cancel() - baseLayers, err = parseFromModel(ctx, fromName, fn) - if err != nil { - ch <- gin.H{"error": err.Error()} + baseLayers, err = parseFromModel(ctx, fromName, fn) + if err != nil { + ch <- gin.H{"error": err.Error()} + } } } else if r.Files != nil { baseLayers, err = convertModelFromFiles(r.Files, baseLayers, false, fn) @@ -110,7 +138,7 @@ func (s *Server) CreateHandler(c *gin.Context) { } var adapterLayers []*layerGGML - if r.Adapters != nil { + if !remote && r.Adapters != nil { adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn) if err != nil { for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} { @@ -128,7 +156,56 @@ func (s *Server) CreateHandler(c *gin.Context) { baseLayers = append(baseLayers, adapterLayers...) } - if err := createModel(r, name, baseLayers, fn); err != nil { + // Info is not currently exposed by Modelfiles, but allows overriding various + // config values + if r.Info != nil { + caps, ok := r.Info["capabilities"] + if ok { + switch tcaps := caps.(type) { + case []any: + caps := make([]string, len(tcaps)) + for i, c := range tcaps { + str, ok := c.(string) + if !ok { + continue + } + caps[i] = str + } + config.Capabilities = append(config.Capabilities, caps...) + } + } + + strFromInfo := func(k string) string { + v, ok := r.Info[k] + if ok { + val := v.(string) + return val + } + return "" + } + + vFromInfo := func(k string) float64 { + v, ok := r.Info[k] + if ok { + val := v.(float64) + return val + } + return 0 + } + + config.ModelFamily = strFromInfo("model_family") + if config.ModelFamily != "" { + config.ModelFamilies = []string{config.ModelFamily} + } + + config.BaseName = strFromInfo("base_name") + config.FileType = strFromInfo("quantization_level") + config.ModelType = strFromInfo("parameter_size") + config.ContextLen = int(vFromInfo("context_length")) + config.EmbedLen = int(vFromInfo("embedding_length")) + } + + if err := createModel(r, name, baseLayers, config, fn); err != nil { if errors.Is(err, errBadTemplate) { ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest} return @@ -154,6 +231,51 @@ func (s *Server) CreateHandler(c *gin.Context) { streamResponse(c, ch) } +func remoteURL(raw string) (string, error) { + // Special‑case: user supplied only a path ("/foo/bar"). 
+ if strings.HasPrefix(raw, "/") { + return (&url.URL{ + Scheme: "http", + Host: net.JoinHostPort("localhost", "11434"), + Path: path.Clean(raw), + }).String(), nil + } + + if !strings.Contains(raw, "://") { + raw = "http://" + raw + } + + if raw == "ollama.com" || raw == "http://ollama.com" { + raw = "https://ollama.com:443" + } + + u, err := url.Parse(raw) + if err != nil { + return "", fmt.Errorf("parse error: %w", err) + } + + if u.Host == "" { + u.Host = "localhost" + } + + hostPart, portPart, err := net.SplitHostPort(u.Host) + if err == nil { + u.Host = net.JoinHostPort(hostPart, portPart) + } else { + u.Host = net.JoinHostPort(u.Host, "11434") + } + + if u.Path != "" { + u.Path = path.Clean(u.Path) + } + + if u.Path == "/" { + u.Path = "" + } + + return u.String(), nil +} + func convertModelFromFiles(files map[string]string, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) { switch detectModelTypeFromFiles(files) { case "safetensors": @@ -316,17 +438,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) { return ggml.KV{}, fmt.Errorf("no base model was found") } -func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, fn func(resp api.ProgressResponse)) (err error) { - config := ConfigV2{ - OS: "linux", - Architecture: "amd64", - RootFS: RootFS{ - Type: "layers", - }, - Renderer: r.Renderer, - Parser: r.Parser, - } - +func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *ConfigV2, fn func(resp api.ProgressResponse)) (err error) { var layers []Layer for _, layer := range baseLayers { if layer.GGML != nil { @@ -406,7 +518,7 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, return err } - configLayer, err := createConfigLayer(layers, config) + configLayer, err := createConfigLayer(layers, *config) if err != nil { return err } diff --git a/server/create_test.go b/server/create_test.go index 59a07ff148..061efb81aa 100644 --- a/server/create_test.go +++ b/server/create_test.go @@ -104,3 +104,154 @@ func TestConvertFromSafetensors(t *testing.T) { }) } } + +func TestRemoteURL(t *testing.T) { + tests := []struct { + name string + input string + expected string + hasError bool + }{ + { + name: "absolute path", + input: "/foo/bar", + expected: "http://localhost:11434/foo/bar", + hasError: false, + }, + { + name: "absolute path with cleanup", + input: "/foo/../bar", + expected: "http://localhost:11434/bar", + hasError: false, + }, + { + name: "root path", + input: "/", + expected: "http://localhost:11434/", + hasError: false, + }, + { + name: "host without scheme", + input: "example.com", + expected: "http://example.com:11434", + hasError: false, + }, + { + name: "host with port", + input: "example.com:8080", + expected: "http://example.com:8080", + hasError: false, + }, + { + name: "full URL", + input: "https://example.com:8080/path", + expected: "https://example.com:8080/path", + hasError: false, + }, + { + name: "full URL with path cleanup", + input: "https://example.com:8080/path/../other", + expected: "https://example.com:8080/other", + hasError: false, + }, + { + name: "ollama.com special case", + input: "ollama.com", + expected: "https://ollama.com:443", + hasError: false, + }, + { + name: "http ollama.com special case", + input: "http://ollama.com", + expected: "https://ollama.com:443", + hasError: false, + }, + { + name: "URL with only host", + input: "http://example.com", + expected: "http://example.com:11434", + hasError: false, + }, 
+ { + name: "URL with root path cleaned", + input: "http://example.com/", + expected: "http://example.com:11434", + hasError: false, + }, + { + name: "invalid URL", + input: "http://[::1]:namedport", // invalid port + expected: "", + hasError: true, + }, + { + name: "empty string", + input: "", + expected: "http://localhost:11434", + hasError: false, + }, + { + name: "host with scheme but no port", + input: "http://localhost", + expected: "http://localhost:11434", + hasError: false, + }, + { + name: "complex path cleanup", + input: "/a/b/../../c/./d", + expected: "http://localhost:11434/c/d", + hasError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := remoteURL(tt.input) + + if tt.hasError { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if result != tt.expected { + t.Errorf("expected %q, got %q", tt.expected, result) + } + }) + } +} + +func TestRemoteURL_Idempotent(t *testing.T) { + // Test that applying remoteURL twice gives the same result as applying it once + testInputs := []string{ + "/foo/bar", + "example.com", + "https://example.com:8080/path", + "ollama.com", + "http://localhost:11434", + } + + for _, input := range testInputs { + t.Run(input, func(t *testing.T) { + firstResult, err := remoteURL(input) + if err != nil { + t.Fatalf("first call failed: %v", err) + } + + secondResult, err := remoteURL(firstResult) + if err != nil { + t.Fatalf("second call failed: %v", err) + } + + if firstResult != secondResult { + t.Errorf("function is not idempotent: first=%q, second=%q", firstResult, secondResult) + } + }) + } +} diff --git a/server/images.go b/server/images.go index 6432860f8e..9466b7fb47 100644 --- a/server/images.go +++ b/server/images.go @@ -74,21 +74,29 @@ func (m *Model) Capabilities() []model.Capability { capabilities := []model.Capability{} // Check for completion capability - f, err := gguf.Open(m.ModelPath) - if err == nil { - defer f.Close() + if m.ModelPath != "" { + f, err := gguf.Open(m.ModelPath) + if err == nil { + defer f.Close() - if f.KeyValue("pooling_type").Valid() { - capabilities = append(capabilities, model.CapabilityEmbedding) + if f.KeyValue("pooling_type").Valid() { + capabilities = append(capabilities, model.CapabilityEmbedding) + } else { + // If no embedding is specified, we assume the model supports completion + capabilities = append(capabilities, model.CapabilityCompletion) + } + if f.KeyValue("vision.block_count").Valid() { + capabilities = append(capabilities, model.CapabilityVision) + } } else { - // If no embedding is specified, we assume the model supports completion - capabilities = append(capabilities, model.CapabilityCompletion) + slog.Error("couldn't open model file", "error", err) } - if f.KeyValue("vision.block_count").Valid() { - capabilities = append(capabilities, model.CapabilityVision) + } else if len(m.Config.Capabilities) > 0 { + for _, c := range m.Config.Capabilities { + capabilities = append(capabilities, model.Capability(c)) } } else { - slog.Error("couldn't open model file", "error", err) + slog.Warn("unknown capabilities for model", "model", m.Name) } if m.Template == nil { @@ -111,6 +119,11 @@ func (m *Model) Capabilities() []model.Capability { capabilities = append(capabilities, model.CapabilityVision) } + // Skip the thinking check if it's already set + if slices.Contains(capabilities, "thinking") { + return capabilities + } + // Check for thinking capability 
openingTag, closingTag := thinking.InferTags(m.Template.Template) hasTags := openingTag != "" && closingTag != "" @@ -253,11 +266,20 @@ type ConfigV2 struct { ModelFormat string `json:"model_format"` ModelFamily string `json:"model_family"` ModelFamilies []string `json:"model_families"` - ModelType string `json:"model_type"` - FileType string `json:"file_type"` + ModelType string `json:"model_type"` // shown as Parameter Size + FileType string `json:"file_type"` // shown as Quantization Level Renderer string `json:"renderer,omitempty"` Parser string `json:"parser,omitempty"` + RemoteHost string `json:"remote_host,omitempty"` + RemoteModel string `json:"remote_model,omitempty"` + + // used for remotes + Capabilities []string `json:"capabilities,omitempty"` + ContextLen int `json:"context_length,omitempty"` + EmbedLen int `json:"embedding_length,omitempty"` + BaseName string `json:"base_name,omitempty"` + // required by spec Architecture string `json:"architecture"` OS string `json:"os"` diff --git a/server/routes.go b/server/routes.go index e999c6c01e..dc868038ce 100644 --- a/server/routes.go +++ b/server/routes.go @@ -15,6 +15,7 @@ import ( "net" "net/http" "net/netip" + "net/url" "os" "os/signal" "slices" @@ -28,6 +29,7 @@ import ( "golang.org/x/sync/errgroup" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/auth" "github.com/ollama/ollama/discover" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" @@ -189,6 +191,84 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } + if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" { + origModel := req.Model + + remoteURL, err := url.Parse(m.Config.RemoteHost) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) { + slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname()) + c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"}) + return + } + + req.Model = m.Config.RemoteModel + + if req.Template == "" && m.Template.String() != "" { + req.Template = m.Template.String() + } + + if req.Options == nil { + req.Options = map[string]any{} + } + + for k, v := range m.Options { + if _, ok := req.Options[k]; !ok { + req.Options[k] = v + } + } + + // update the system prompt from the model if one isn't already specified + if req.System == "" && m.System != "" { + req.System = m.System + } + + if len(m.Messages) > 0 { + slog.Warn("embedded messages in the model not supported with '/api/generate'; try '/api/chat' instead") + } + + fn := func(resp api.GenerateResponse) error { + resp.Model = origModel + resp.RemoteModel = m.Config.RemoteModel + resp.RemoteHost = m.Config.RemoteHost + + data, err := json.Marshal(resp) + if err != nil { + return err + } + + if _, err = c.Writer.Write(append(data, '\n')); err != nil { + return err + } + c.Writer.Flush() + return nil + } + + client := api.NewClient(remoteURL, http.DefaultClient) + err = client.Generate(c, &req, fn) + if err != nil { + var sErr api.AuthorizationError + if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized { + pk, pkErr := auth.GetPublicKey() + if pkErr != nil { + slog.Error("couldn't get public key", "error", pkErr) + c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"}) + return + } + c.JSON(http.StatusUnauthorized, gin.H{"public_key": pk}) + return + } + c.JSON(http.StatusInternalServerError, 
gin.H{"error": err.Error()}) + return + } + + return + } + // expire the runner if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 { s.sched.expireRunner(m) @@ -931,6 +1011,28 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { ModifiedAt: manifest.fi.ModTime(), } + if m.Config.RemoteHost != "" { + resp.RemoteHost = m.Config.RemoteHost + resp.RemoteModel = m.Config.RemoteModel + + if m.Config.ModelFamily != "" { + resp.ModelInfo = make(map[string]any) + resp.ModelInfo["general.architecture"] = m.Config.ModelFamily + + if m.Config.BaseName != "" { + resp.ModelInfo["general.basename"] = m.Config.BaseName + } + + if m.Config.ContextLen > 0 { + resp.ModelInfo[fmt.Sprintf("%s.context_length", m.Config.ModelFamily)] = m.Config.ContextLen + } + + if m.Config.EmbedLen > 0 { + resp.ModelInfo[fmt.Sprintf("%s.embedding_length", m.Config.ModelFamily)] = m.Config.EmbedLen + } + } + } + var params []string cs := 30 for k, v := range m.Options { @@ -961,6 +1063,11 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { fmt.Fprint(&sb, m.String()) resp.Modelfile = sb.String() + // skip loading tensor information if this is a remote model + if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" { + return resp, nil + } + kvData, tensors, err := getModelData(m.ModelPath, req.Verbose) if err != nil { return nil, err @@ -1037,11 +1144,13 @@ func (s *Server) ListHandler(c *gin.Context) { // tag should never be masked models = append(models, api.ListModelResponse{ - Model: n.DisplayShortest(), - Name: n.DisplayShortest(), - Size: m.Size(), - Digest: m.digest, - ModifiedAt: m.fi.ModTime(), + Model: n.DisplayShortest(), + Name: n.DisplayShortest(), + RemoteModel: cf.RemoteModel, + RemoteHost: cf.RemoteHost, + Size: m.Size(), + Digest: m.digest, + ModifiedAt: m.fi.ModTime(), Details: api.ModelDetails{ Format: cf.ModelFormat, Family: cf.ModelFamily, @@ -1301,6 +1410,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { r.POST("/api/show", s.ShowHandler) r.DELETE("/api/delete", s.DeleteHandler) + r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler) + r.POST("/api/me", s.WhoamiHandler) + // Create r.POST("/api/create", s.CreateHandler) r.POST("/api/blobs/:digest", s.CreateBlobHandler) @@ -1497,6 +1609,49 @@ func streamResponse(c *gin.Context, ch chan any) { }) } +func (s *Server) WhoamiHandler(c *gin.Context) { + // todo allow other hosts + u, err := url.Parse("https://ollama.com") + if err != nil { + slog.Error(err.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"}) + return + } + + client := api.NewClient(u, http.DefaultClient) + user, err := client.Whoami(c) + if err != nil { + slog.Error(err.Error()) + } + c.JSON(http.StatusOK, user) +} + +func (s *Server) SignoutHandler(c *gin.Context) { + encodedKey := c.Param("encodedKey") + + // todo allow other hosts + u, err := url.Parse("https://ollama.com") + if err != nil { + slog.Error(err.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"}) + return + } + + client := api.NewClient(u, http.DefaultClient) + err = client.Signout(c, encodedKey) + if err != nil { + slog.Error(err.Error()) + if strings.Contains(err.Error(), "page not found") || strings.Contains(err.Error(), "invalid credentials") { + c.JSON(http.StatusNotFound, gin.H{"error": "you are not currently signed in"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"}) + return + } + + c.JSON(http.StatusOK, 
nil) +} + func (s *Server) PsHandler(c *gin.Context) { models := []api.ProcessModelResponse{} @@ -1553,21 +1708,34 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - // expire the runner - if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 { - model, err := GetModel(req.Model) - if err != nil { - switch { - case os.IsNotExist(err): - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) - case err.Error() == errtypes.InvalidModelNameErrMsg: - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - default: - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - } - return + name := model.ParseName(req.Model) + if !name.IsValid() { + c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) + return + } + + name, err := getExistingName(name) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) + return + } + + m, err := GetModel(req.Model) + if err != nil { + switch { + case os.IsNotExist(err): + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) + case err.Error() == errtypes.InvalidModelNameErrMsg: + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + default: + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) } - s.sched.expireRunner(model) + return + } + + // expire the runner + if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 { + s.sched.expireRunner(m) c.JSON(http.StatusOK, api.ChatResponse{ Model: req.Model, @@ -1579,6 +1747,66 @@ func (s *Server) ChatHandler(c *gin.Context) { return } + if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" { + origModel := req.Model + + remoteURL, err := url.Parse(m.Config.RemoteHost) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) { + slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname()) + c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"}) + return + } + + req.Model = m.Config.RemoteModel + if req.Options == nil { + req.Options = map[string]any{} + } + + msgs := append(m.Messages, req.Messages...) + if req.Messages[0].Role != "system" && m.System != "" { + msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...) 
+ } + msgs = filterThinkTags(msgs, m) + req.Messages = msgs + + for k, v := range m.Options { + if _, ok := req.Options[k]; !ok { + req.Options[k] = v + } + } + + fn := func(resp api.ChatResponse) error { + resp.Model = origModel + resp.RemoteModel = m.Config.RemoteModel + resp.RemoteHost = m.Config.RemoteHost + + data, err := json.Marshal(resp) + if err != nil { + return err + } + + if _, err = c.Writer.Write(append(data, '\n')); err != nil { + return err + } + c.Writer.Flush() + return nil + } + + client := api.NewClient(remoteURL, http.DefaultClient) + err = client.Chat(c, &req, fn) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + return + } + caps := []model.Capability{model.CapabilityCompletion} if len(req.Tools) > 0 { caps = append(caps, model.CapabilityTools) @@ -1587,17 +1815,6 @@ func (s *Server) ChatHandler(c *gin.Context) { caps = append(caps, model.CapabilityThinking) } - name := model.ParseName(req.Model) - if !name.IsValid() { - c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) - return - } - name, err := getExistingName(name) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) - return - } - r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive) if errors.Is(err, errCapabilityCompletion) { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)}) diff --git a/server/routes_create_test.go b/server/routes_create_test.go index 3b3d99100d..189ef04070 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -11,6 +11,7 @@ import ( "net/http/httptest" "os" "path/filepath" + "reflect" "slices" "strings" "testing" @@ -20,6 +21,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/types/model" ) var stream bool = false @@ -615,6 +617,78 @@ func TestCreateTemplateSystem(t *testing.T) { }) } +func TestCreateAndShowRemoteModel(t *testing.T) { + gin.SetMode(gin.TestMode) + + var s Server + + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Model: "test", + From: "bob", + RemoteHost: "https://ollama.com", + Info: map[string]any{ + "capabilities": []string{"completion", "tools", "thinking"}, + "model_family": "gptoss", + "context_length": 131072, + "embedding_length": 2880, + "quantization_level": "MXFP4", + "parameter_size": "20.9B", + }, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("exected status code 200, actual %d", w.Code) + } + + w = createRequest(t, s.ShowHandler, api.ShowRequest{Model: "test"}) + if w.Code != http.StatusOK { + t.Fatalf("exected status code 200, actual %d", w.Code) + } + + var resp api.ShowResponse + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatal(err) + } + + expectedDetails := api.ModelDetails{ + ParentModel: "", + Format: "", + Family: "gptoss", + Families: []string{"gptoss"}, + ParameterSize: "20.9B", + QuantizationLevel: "MXFP4", + } + + if !reflect.DeepEqual(resp.Details, expectedDetails) { + t.Errorf("model details: expected %#v, actual %#v", expectedDetails, resp.Details) + } + + expectedCaps := []model.Capability{ + model.Capability("completion"), + model.Capability("tools"), + model.Capability("thinking"), + } + + if !slices.Equal(resp.Capabilities, expectedCaps) { + t.Errorf("capabilities: expected %#v, actual %#v", expectedCaps, resp.Capabilities) + } + + v, ok := 
resp.ModelInfo["gptoss.context_length"] + ctxlen := v.(float64) + if !ok || int(ctxlen) != 131072 { + t.Errorf("context len: expected %d, actual %d", 131072, int(ctxlen)) + } + + v, ok = resp.ModelInfo["gptoss.embedding_length"] + embedlen := v.(float64) + if !ok || int(embedlen) != 2880 { + t.Errorf("embed len: expected %d, actual %d", 2880, int(embedlen)) + } + + fmt.Printf("resp = %#v\n", resp) +} + func TestCreateLicenses(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/server/routes_test.go b/server/routes_test.go index 87b5266337..bb7e2b7c12 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -126,7 +126,15 @@ func TestRoutes(t *testing.T) { t.Fatalf("failed to create model: %v", err) } - if err := createModel(r, modelName, baseLayers, fn); err != nil { + config := &ConfigV2{ + OS: "linux", + Architecture: "amd64", + RootFS: RootFS{ + Type: "layers", + }, + } + + if err := createModel(r, modelName, baseLayers, config, fn); err != nil { t.Fatal(err) } }