diff --git a/api/client.go b/api/client.go index 4688d4d13..bdc449e13 100644 --- a/api/client.go +++ b/api/client.go @@ -163,24 +163,29 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f scanBuf := make([]byte, 0, maxBufferSize) scanner.Buffer(scanBuf, maxBufferSize) for scanner.Scan() { - var errorResponse struct { - Error string `json:"error,omitempty"` - } - bts := scanner.Bytes() + + var errorResponse ErrorResponse if err := json.Unmarshal(bts, &errorResponse); err != nil { return fmt.Errorf("unmarshal: %w", err) } - if errorResponse.Error != "" { - return errors.New(errorResponse.Error) + switch errorResponse.Code { + case ErrCodeUnknownKey: + return ErrUnknownOllamaKey{ + Message: errorResponse.Message, + Key: errorResponse.Data["key"].(string), + } + } + if errorResponse.Message != "" { + return errors.New(errorResponse.Message) } if response.StatusCode >= http.StatusBadRequest { return StatusError{ StatusCode: response.StatusCode, Status: response.Status, - ErrorMessage: errorResponse.Error, + ErrorMessage: errorResponse.Message, } } diff --git a/api/client_test.go b/api/client_test.go index 23fe9334b..6c3b6bd3f 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -1,6 +1,12 @@ package api import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "net/url" "testing" ) @@ -43,3 +49,117 @@ func TestClientFromEnvironment(t *testing.T) { }) } } + +func TestStream(t *testing.T) { + tests := []struct { + name string + serverResponse []string + statusCode int + expectedError error + }{ + { + name: "unknown key error", + serverResponse: []string{ + `{"error":"unauthorized access","code":"unknown_key","data":{"key":"test-key"}}`, + }, + statusCode: http.StatusUnauthorized, + expectedError: &ErrUnknownOllamaKey{ + Message: "unauthorized access", + Key: "test-key", + }, + }, + { + name: "general error message", + serverResponse: []string{ + `{"error":"something went wrong"}`, + }, + statusCode: http.StatusInternalServerError, + expectedError: fmt.Errorf("something went wrong"), + }, + { + name: "malformed json response", + serverResponse: []string{ + `{invalid-json`, + }, + statusCode: http.StatusOK, + expectedError: fmt.Errorf("unmarshal: invalid character 'i' looking for beginning of object key string"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/x-ndjson") + w.WriteHeader(tt.statusCode) + for _, resp := range tt.serverResponse { + fmt.Fprintln(w, resp) + } + })) + defer server.Close() + + baseURL, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("failed to parse server URL: %v", err) + } + + client := &Client{ + http: server.Client(), + base: baseURL, + } + + var responses [][]byte + err = client.stream(context.Background(), "POST", "/test", "test", func(bts []byte) error { + responses = append(responses, bts) + return nil + }) + + // Error checking + if tt.expectedError == nil { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + return + } + + if err == nil { + t.Fatalf("expected error %v, got nil", tt.expectedError) + } + + // Check for specific error types + var unknownKeyErr ErrUnknownOllamaKey + if errors.As(tt.expectedError, &unknownKeyErr) { + var gotErr ErrUnknownOllamaKey + if !errors.As(err, &gotErr) { + t.Fatalf("expected ErrUnknownOllamaKey, got %T", err) + } + if unknownKeyErr.Key != gotErr.Key { + t.Errorf("expected key %q, got %q", unknownKeyErr.Key, gotErr.Key) + } + if unknownKeyErr.Message != gotErr.Message { + t.Errorf("expected message %q, got %q", unknownKeyErr.Message, gotErr.Message) + } + return + } + + var statusErr StatusError + if errors.As(tt.expectedError, &statusErr) { + var gotErr StatusError + if !errors.As(err, &gotErr) { + t.Fatalf("expected StatusError, got %T", err) + } + if statusErr.StatusCode != gotErr.StatusCode { + t.Errorf("expected status code %d, got %d", statusErr.StatusCode, gotErr.StatusCode) + } + if statusErr.ErrorMessage != gotErr.ErrorMessage { + t.Errorf("expected error message %q, got %q", statusErr.ErrorMessage, gotErr.ErrorMessage) + } + return + } + + // For other errors, compare error strings + if err.Error() != tt.expectedError.Error() { + t.Errorf("expected error %q, got %q", tt.expectedError, err) + } + }) + } +} diff --git a/api/errors.go b/api/errors.go new file mode 100644 index 000000000..54ad6c918 --- /dev/null +++ b/api/errors.go @@ -0,0 +1,74 @@ +package api + +import ( + "fmt" + "slices" + "strings" +) + +const InvalidModelNameErrMsg = "invalid model name" + +// API error responses +// ErrorCode represents a standardized error code identifier +type ErrorCode string + +const ( + ErrCodeUnknownKey ErrorCode = "unknown_key" + ErrCodeGeneral ErrorCode = "general" // Generic fallback error code +) + +// ErrorResponse implements a structured error interface +type ErrorResponse struct { + Message string `json:"error"` // Human-readable error message, uses 'error' field name for backwards compatibility + Code ErrorCode `json:"code"` // Machine-readable error code for programmatic handling, not response code + Data map[string]any `json:"data"` // Additional error specific data, if any +} + +func (e ErrorResponse) Error() string { + return e.Message +} + +type ErrUnknownOllamaKey struct { + Message string + Key string +} + +func (e ErrUnknownOllamaKey) Error() string { + return fmt.Sprintf("unauthorized: unknown ollama key %q", strings.TrimSpace(e.Key)) +} + +func (e *ErrUnknownOllamaKey) FormatUserMessage(localKeys []string) string { + // The user should only be told to add the key if it is the same one that exists locally + if slices.Index(localKeys, e.Key) == -1 { + return e.Message + } + + return fmt.Sprintf(`%s + +Your ollama key is: +%s +Add your key at: +https://ollama.com/settings/keys`, e.Message, e.Key) +} + +// StatusError is an error with an HTTP status code and message, +// it is parsed on the client-side and not returned from the API +type StatusError struct { + StatusCode int // e.g. 200 + Status string // e.g. "200 OK" + ErrorMessage string `json:"error"` +} + +func (e StatusError) Error() string { + switch { + case e.Status != "" && e.ErrorMessage != "": + return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage) + case e.Status != "": + return e.Status + case e.ErrorMessage != "": + return e.ErrorMessage + default: + // this should not happen + return "something went wrong, please see the ollama server logs for details" + } +} diff --git a/api/types.go b/api/types.go index e5291a024..6728cb743 100644 --- a/api/types.go +++ b/api/types.go @@ -12,27 +12,6 @@ import ( "time" ) -// StatusError is an error with an HTTP status code and message. -type StatusError struct { - StatusCode int - Status string - ErrorMessage string `json:"error"` -} - -func (e StatusError) Error() string { - switch { - case e.Status != "" && e.ErrorMessage != "": - return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage) - case e.Status != "": - return e.Status - case e.ErrorMessage != "": - return e.ErrorMessage - default: - // this should not happen - return "something went wrong, please see the ollama server logs for details" - } -} - // ImageData represents the raw binary data of an image file. type ImageData []byte diff --git a/cmd/cmd.go b/cmd/cmd.go index ce7187aee..9a8d95a4d 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -19,7 +19,6 @@ import ( "os" "os/signal" "path/filepath" - "regexp" "runtime" "strconv" "strings" @@ -41,7 +40,6 @@ import ( "github.com/ollama/ollama/parser" "github.com/ollama/ollama/progress" "github.com/ollama/ollama/server" - "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" ) @@ -516,46 +514,22 @@ func RunHandler(cmd *cobra.Command, args []string) error { return generate(cmd, opts) } -func errFromUnknownKey(unknownKeyErr error) error { - // find SSH public key in the error message - // TODO (brucemacd): the API should return structured errors so that this message parsing isn't needed - sshKeyPattern := `ssh-\w+ [^\s"]+` - re := regexp.MustCompile(sshKeyPattern) - matches := re.FindStringSubmatch(unknownKeyErr.Error()) - - if len(matches) > 0 { - serverPubKey := matches[0] - - localPubKey, err := auth.GetPublicKey() - if err != nil { - return unknownKeyErr - } - - if runtime.GOOS == "linux" && serverPubKey != localPubKey { - // try the ollama service public key - svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub") - if err != nil { - return unknownKeyErr - } - localPubKey = strings.TrimSpace(string(svcPubKey)) - } - - // check if the returned public key matches the local public key, this prevents adding a remote key to the user's account - if serverPubKey != localPubKey { - return unknownKeyErr - } - - var msg strings.Builder - msg.WriteString(unknownKeyErr.Error()) - msg.WriteString("\n\nYour ollama key is:\n") - msg.WriteString(localPubKey) - msg.WriteString("\nAdd your key at:\n") - msg.WriteString("https://ollama.com/settings/keys") - - return errors.New(msg.String()) +func localPubKeys() ([]string, error) { + usrKey, err := auth.GetPublicKey() + if err != nil { + return nil, err } - return unknownKeyErr + keys := []string{usrKey} + + if runtime.GOOS == "linux" { + // try the ollama service public key if on Linux + if svcKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub"); err == nil { + keys = append(keys, strings.TrimSpace(string(svcKey))) + } + } + + return keys, nil } func PushHandler(cmd *cobra.Command, args []string) error { @@ -611,15 +585,18 @@ func PushHandler(cmd *cobra.Command, args []string) error { if spinner != nil { spinner.Stop() } + var ke api.ErrUnknownOllamaKey + if errors.As(err, &ke) && isOllamaHost { + + // the user has not added their ollama key to ollama.com + // return an error with a more user-friendly message + locals, _ := localPubKeys() + return errors.New(ke.FormatUserMessage(locals)) + } if strings.Contains(err.Error(), "access denied") { return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own") } - if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost { - // the user has not added their ollama key to ollama.com - // return an error with a more user-friendly message - return errFromUnknownKey(err) - } - return err + return fmt.Errorf("yoyoyo: %w", err) } p.Stop() diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index 3a8e44a7e..6a1cd61e5 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -16,7 +16,6 @@ import ( "github.com/spf13/cobra" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/types/errtypes" ) func TestShowInfo(t *testing.T) { @@ -437,7 +436,7 @@ func TestPushHandler(t *testing.T) { "/api/push": func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) - uerr := errtypes.UnknownOllamaKey{ + uerr := api.ErrUnknownOllamaKey{ Key: "aaa", } err := json.NewEncoder(w).Encode(map[string]string{ diff --git a/cmd/interactive.go b/cmd/interactive.go index 9035b4c52..0f214c560 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -19,7 +19,6 @@ import ( "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/readline" - "github.com/ollama/ollama/types/errtypes" ) type MultilineState int @@ -220,7 +219,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { fn := func(resp api.ProgressResponse) error { return nil } err = client.Create(cmd.Context(), req, fn) if err != nil { - if strings.Contains(err.Error(), errtypes.InvalidModelNameErrMsg) { + if strings.Contains(err.Error(), api.InvalidModelNameErrMsg) { fmt.Printf("error: The model name '%s' is invalid\n", args[1]) continue } diff --git a/server/images.go b/server/images.go index cda8eb317..08c86da2f 100644 --- a/server/images.go +++ b/server/images.go @@ -30,7 +30,6 @@ import ( "github.com/ollama/ollama/llm" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/template" - "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/registry" "github.com/ollama/ollama/version" @@ -1031,7 +1030,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr)) return nil, re } - return nil, errtypes.UnknownOllamaKey{ + return nil, api.ErrUnknownOllamaKey{ Key: pubKey, } } diff --git a/server/routes.go b/server/routes.go index c13cd023f..848afe6e0 100644 --- a/server/routes.go +++ b/server/routes.go @@ -36,7 +36,6 @@ import ( "github.com/ollama/ollama/runners" "github.com/ollama/ollama/server/imageproc" "github.com/ollama/ollama/template" - "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" ) @@ -610,7 +609,7 @@ func (s *Server) PushHandler(c *gin.Context) { defer cancel() if err := PushModel(ctx, model, regOpts, fn); err != nil { - ch <- gin.H{"error": err.Error()} + ch <- newErr(err) } }() @@ -650,7 +649,7 @@ func (s *Server) CreateHandler(c *gin.Context) { name := model.ParseName(cmp.Or(r.Model, r.Name)) if !name.IsValid() { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg}) + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": api.InvalidModelNameErrMsg}) return } @@ -1550,3 +1549,24 @@ func handleScheduleError(c *gin.Context, name string, err error) { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) } } + +// newErr creates a structured API ErrorResponse from an existing error +func newErr(err error) api.ErrorResponse { + if err == nil { + return api.ErrorResponse{} + } + // Default to just returning the generic error message + resp := api.ErrorResponse{ + Code: api.ErrCodeGeneral, + Message: err.Error(), + } + // Add additional error specific data, if any + var errResp api.ErrUnknownOllamaKey + if errors.As(err, &errResp) { + resp.Code = api.ErrCodeUnknownKey + resp.Data = map[string]any{ + "key": errResp.Key, + } + } + return resp +} diff --git a/types/errtypes/errtypes.go b/types/errtypes/errtypes.go deleted file mode 100644 index 814b58b03..000000000 --- a/types/errtypes/errtypes.go +++ /dev/null @@ -1,21 +0,0 @@ -// Package errtypes contains custom error types -package errtypes - -import ( - "fmt" - "strings" -) - -const ( - UnknownOllamaKeyErrMsg = "unknown ollama key" - InvalidModelNameErrMsg = "invalid model name" -) - -// TODO: This should have a structured response from the API -type UnknownOllamaKey struct { - Key string -} - -func (e UnknownOllamaKey) Error() string { - return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key)) -}