mirror of
https://github.com/ollama/ollama.git
synced 2025-04-11 05:09:45 +02:00
api: return structured error on unauthorized push
This commit implements a structured error response system for the Ollama API, replacing ad-hoc error handling and string parsing with proper error types and codes through a new ErrorResponse struct. Instead of relying on regex to parse error messages for SSH keys, the API now passes this data in a structured format with standardized fields for error messages, codes, and additional data. This structured approach makes the API more maintainable and reliable while improving the developer experience by enabling programmatic error handling, consistent error formats, and better error documentation.
This commit is contained in:
parent
ae9165d661
commit
9e190ac4d9
@ -163,24 +163,29 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
scanBuf := make([]byte, 0, maxBufferSize)
|
||||
scanner.Buffer(scanBuf, maxBufferSize)
|
||||
for scanner.Scan() {
|
||||
var errorResponse struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
bts := scanner.Bytes()
|
||||
|
||||
var errorResponse ErrorResponse
|
||||
if err := json.Unmarshal(bts, &errorResponse); err != nil {
|
||||
return fmt.Errorf("unmarshal: %w", err)
|
||||
}
|
||||
|
||||
if errorResponse.Error != "" {
|
||||
return errors.New(errorResponse.Error)
|
||||
switch errorResponse.Code {
|
||||
case ErrCodeUnknownKey:
|
||||
return ErrUnknownOllamaKey{
|
||||
Message: errorResponse.Message,
|
||||
Key: errorResponse.Data["key"].(string),
|
||||
}
|
||||
}
|
||||
if errorResponse.Message != "" {
|
||||
return errors.New(errorResponse.Message)
|
||||
}
|
||||
|
||||
if response.StatusCode >= http.StatusBadRequest {
|
||||
return StatusError{
|
||||
StatusCode: response.StatusCode,
|
||||
Status: response.Status,
|
||||
ErrorMessage: errorResponse.Error,
|
||||
ErrorMessage: errorResponse.Message,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,12 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -43,3 +49,117 @@ func TestClientFromEnvironment(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStream(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverResponse []string
|
||||
statusCode int
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "unknown key error",
|
||||
serverResponse: []string{
|
||||
`{"error":"unauthorized access","code":"unknown_key","data":{"key":"test-key"}}`,
|
||||
},
|
||||
statusCode: http.StatusUnauthorized,
|
||||
expectedError: &ErrUnknownOllamaKey{
|
||||
Message: "unauthorized access",
|
||||
Key: "test-key",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "general error message",
|
||||
serverResponse: []string{
|
||||
`{"error":"something went wrong"}`,
|
||||
},
|
||||
statusCode: http.StatusInternalServerError,
|
||||
expectedError: fmt.Errorf("something went wrong"),
|
||||
},
|
||||
{
|
||||
name: "malformed json response",
|
||||
serverResponse: []string{
|
||||
`{invalid-json`,
|
||||
},
|
||||
statusCode: http.StatusOK,
|
||||
expectedError: fmt.Errorf("unmarshal: invalid character 'i' looking for beginning of object key string"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/x-ndjson")
|
||||
w.WriteHeader(tt.statusCode)
|
||||
for _, resp := range tt.serverResponse {
|
||||
fmt.Fprintln(w, resp)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
baseURL, err := url.Parse(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse server URL: %v", err)
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
http: server.Client(),
|
||||
base: baseURL,
|
||||
}
|
||||
|
||||
var responses [][]byte
|
||||
err = client.stream(context.Background(), "POST", "/test", "test", func(bts []byte) error {
|
||||
responses = append(responses, bts)
|
||||
return nil
|
||||
})
|
||||
|
||||
// Error checking
|
||||
if tt.expectedError == nil {
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
t.Fatalf("expected error %v, got nil", tt.expectedError)
|
||||
}
|
||||
|
||||
// Check for specific error types
|
||||
var unknownKeyErr ErrUnknownOllamaKey
|
||||
if errors.As(tt.expectedError, &unknownKeyErr) {
|
||||
var gotErr ErrUnknownOllamaKey
|
||||
if !errors.As(err, &gotErr) {
|
||||
t.Fatalf("expected ErrUnknownOllamaKey, got %T", err)
|
||||
}
|
||||
if unknownKeyErr.Key != gotErr.Key {
|
||||
t.Errorf("expected key %q, got %q", unknownKeyErr.Key, gotErr.Key)
|
||||
}
|
||||
if unknownKeyErr.Message != gotErr.Message {
|
||||
t.Errorf("expected message %q, got %q", unknownKeyErr.Message, gotErr.Message)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var statusErr StatusError
|
||||
if errors.As(tt.expectedError, &statusErr) {
|
||||
var gotErr StatusError
|
||||
if !errors.As(err, &gotErr) {
|
||||
t.Fatalf("expected StatusError, got %T", err)
|
||||
}
|
||||
if statusErr.StatusCode != gotErr.StatusCode {
|
||||
t.Errorf("expected status code %d, got %d", statusErr.StatusCode, gotErr.StatusCode)
|
||||
}
|
||||
if statusErr.ErrorMessage != gotErr.ErrorMessage {
|
||||
t.Errorf("expected error message %q, got %q", statusErr.ErrorMessage, gotErr.ErrorMessage)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// For other errors, compare error strings
|
||||
if err.Error() != tt.expectedError.Error() {
|
||||
t.Errorf("expected error %q, got %q", tt.expectedError, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
74
api/errors.go
Normal file
74
api/errors.go
Normal file
@ -0,0 +1,74 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const InvalidModelNameErrMsg = "invalid model name"
|
||||
|
||||
// API error responses
|
||||
// ErrorCode represents a standardized error code identifier
|
||||
type ErrorCode string
|
||||
|
||||
const (
|
||||
ErrCodeUnknownKey ErrorCode = "unknown_key"
|
||||
ErrCodeGeneral ErrorCode = "general" // Generic fallback error code
|
||||
)
|
||||
|
||||
// ErrorResponse implements a structured error interface
|
||||
type ErrorResponse struct {
|
||||
Message string `json:"error"` // Human-readable error message, uses 'error' field name for backwards compatibility
|
||||
Code ErrorCode `json:"code"` // Machine-readable error code for programmatic handling, not response code
|
||||
Data map[string]any `json:"data"` // Additional error specific data, if any
|
||||
}
|
||||
|
||||
func (e ErrorResponse) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
type ErrUnknownOllamaKey struct {
|
||||
Message string
|
||||
Key string
|
||||
}
|
||||
|
||||
func (e ErrUnknownOllamaKey) Error() string {
|
||||
return fmt.Sprintf("unauthorized: unknown ollama key %q", strings.TrimSpace(e.Key))
|
||||
}
|
||||
|
||||
func (e *ErrUnknownOllamaKey) FormatUserMessage(localKeys []string) string {
|
||||
// The user should only be told to add the key if it is the same one that exists locally
|
||||
if slices.Index(localKeys, e.Key) == -1 {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`%s
|
||||
|
||||
Your ollama key is:
|
||||
%s
|
||||
Add your key at:
|
||||
https://ollama.com/settings/keys`, e.Message, e.Key)
|
||||
}
|
||||
|
||||
// StatusError is an error with an HTTP status code and message,
|
||||
// it is parsed on the client-side and not returned from the API
|
||||
type StatusError struct {
|
||||
StatusCode int // e.g. 200
|
||||
Status string // e.g. "200 OK"
|
||||
ErrorMessage string `json:"error"`
|
||||
}
|
||||
|
||||
func (e StatusError) Error() string {
|
||||
switch {
|
||||
case e.Status != "" && e.ErrorMessage != "":
|
||||
return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage)
|
||||
case e.Status != "":
|
||||
return e.Status
|
||||
case e.ErrorMessage != "":
|
||||
return e.ErrorMessage
|
||||
default:
|
||||
// this should not happen
|
||||
return "something went wrong, please see the ollama server logs for details"
|
||||
}
|
||||
}
|
21
api/types.go
21
api/types.go
@ -12,27 +12,6 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// StatusError is an error with an HTTP status code and message.
|
||||
type StatusError struct {
|
||||
StatusCode int
|
||||
Status string
|
||||
ErrorMessage string `json:"error"`
|
||||
}
|
||||
|
||||
func (e StatusError) Error() string {
|
||||
switch {
|
||||
case e.Status != "" && e.ErrorMessage != "":
|
||||
return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage)
|
||||
case e.Status != "":
|
||||
return e.Status
|
||||
case e.ErrorMessage != "":
|
||||
return e.ErrorMessage
|
||||
default:
|
||||
// this should not happen
|
||||
return "something went wrong, please see the ollama server logs for details"
|
||||
}
|
||||
}
|
||||
|
||||
// ImageData represents the raw binary data of an image file.
|
||||
type ImageData []byte
|
||||
|
||||
|
69
cmd/cmd.go
69
cmd/cmd.go
@ -19,7 +19,6 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -41,7 +40,6 @@ import (
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
@ -516,46 +514,22 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return generate(cmd, opts)
|
||||
}
|
||||
|
||||
func errFromUnknownKey(unknownKeyErr error) error {
|
||||
// find SSH public key in the error message
|
||||
// TODO (brucemacd): the API should return structured errors so that this message parsing isn't needed
|
||||
sshKeyPattern := `ssh-\w+ [^\s"]+`
|
||||
re := regexp.MustCompile(sshKeyPattern)
|
||||
matches := re.FindStringSubmatch(unknownKeyErr.Error())
|
||||
|
||||
if len(matches) > 0 {
|
||||
serverPubKey := matches[0]
|
||||
|
||||
localPubKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
||||
// try the ollama service public key
|
||||
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
||||
if err != nil {
|
||||
return unknownKeyErr
|
||||
}
|
||||
localPubKey = strings.TrimSpace(string(svcPubKey))
|
||||
}
|
||||
|
||||
// check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
|
||||
if serverPubKey != localPubKey {
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
var msg strings.Builder
|
||||
msg.WriteString(unknownKeyErr.Error())
|
||||
msg.WriteString("\n\nYour ollama key is:\n")
|
||||
msg.WriteString(localPubKey)
|
||||
msg.WriteString("\nAdd your key at:\n")
|
||||
msg.WriteString("https://ollama.com/settings/keys")
|
||||
|
||||
return errors.New(msg.String())
|
||||
func localPubKeys() ([]string, error) {
|
||||
usrKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return unknownKeyErr
|
||||
keys := []string{usrKey}
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
// try the ollama service public key if on Linux
|
||||
if svcKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub"); err == nil {
|
||||
keys = append(keys, strings.TrimSpace(string(svcKey)))
|
||||
}
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
@ -611,15 +585,18 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
var ke api.ErrUnknownOllamaKey
|
||||
if errors.As(err, &ke) && isOllamaHost {
|
||||
|
||||
// the user has not added their ollama key to ollama.com
|
||||
// return an error with a more user-friendly message
|
||||
locals, _ := localPubKeys()
|
||||
return errors.New(ke.FormatUserMessage(locals))
|
||||
}
|
||||
if strings.Contains(err.Error(), "access denied") {
|
||||
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
|
||||
}
|
||||
if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
|
||||
// the user has not added their ollama key to ollama.com
|
||||
// return an error with a more user-friendly message
|
||||
return errFromUnknownKey(err)
|
||||
}
|
||||
return err
|
||||
return fmt.Errorf("yoyoyo: %w", err)
|
||||
}
|
||||
|
||||
p.Stop()
|
||||
|
@ -16,7 +16,6 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
)
|
||||
|
||||
func TestShowInfo(t *testing.T) {
|
||||
@ -437,7 +436,7 @@ func TestPushHandler(t *testing.T) {
|
||||
"/api/push": func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
uerr := errtypes.UnknownOllamaKey{
|
||||
uerr := api.ErrUnknownOllamaKey{
|
||||
Key: "aaa",
|
||||
}
|
||||
err := json.NewEncoder(w).Encode(map[string]string{
|
||||
|
@ -19,7 +19,6 @@ import (
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/readline"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
)
|
||||
|
||||
type MultilineState int
|
||||
@ -220,7 +219,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fn := func(resp api.ProgressResponse) error { return nil }
|
||||
err = client.Create(cmd.Context(), req, fn)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), errtypes.InvalidModelNameErrMsg) {
|
||||
if strings.Contains(err.Error(), api.InvalidModelNameErrMsg) {
|
||||
fmt.Printf("error: The model name '%s' is invalid\n", args[1])
|
||||
continue
|
||||
}
|
||||
|
@ -30,7 +30,6 @@ import (
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/types/registry"
|
||||
"github.com/ollama/ollama/version"
|
||||
@ -1031,7 +1030,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
||||
return nil, re
|
||||
}
|
||||
return nil, errtypes.UnknownOllamaKey{
|
||||
return nil, api.ErrUnknownOllamaKey{
|
||||
Key: pubKey,
|
||||
}
|
||||
}
|
||||
|
@ -36,7 +36,6 @@ import (
|
||||
"github.com/ollama/ollama/runners"
|
||||
"github.com/ollama/ollama/server/imageproc"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
@ -610,7 +609,7 @@ func (s *Server) PushHandler(c *gin.Context) {
|
||||
defer cancel()
|
||||
|
||||
if err := PushModel(ctx, model, regOpts, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
ch <- newErr(err)
|
||||
}
|
||||
}()
|
||||
|
||||
@ -650,7 +649,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
|
||||
name := model.ParseName(cmp.Or(r.Model, r.Name))
|
||||
if !name.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": api.InvalidModelNameErrMsg})
|
||||
return
|
||||
}
|
||||
|
||||
@ -1550,3 +1549,24 @@ func handleScheduleError(c *gin.Context, name string, err error) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
|
||||
// newErr creates a structured API ErrorResponse from an existing error
|
||||
func newErr(err error) api.ErrorResponse {
|
||||
if err == nil {
|
||||
return api.ErrorResponse{}
|
||||
}
|
||||
// Default to just returning the generic error message
|
||||
resp := api.ErrorResponse{
|
||||
Code: api.ErrCodeGeneral,
|
||||
Message: err.Error(),
|
||||
}
|
||||
// Add additional error specific data, if any
|
||||
var errResp api.ErrUnknownOllamaKey
|
||||
if errors.As(err, &errResp) {
|
||||
resp.Code = api.ErrCodeUnknownKey
|
||||
resp.Data = map[string]any{
|
||||
"key": errResp.Key,
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
@ -1,21 +0,0 @@
|
||||
// Package errtypes contains custom error types
|
||||
package errtypes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
UnknownOllamaKeyErrMsg = "unknown ollama key"
|
||||
InvalidModelNameErrMsg = "invalid model name"
|
||||
)
|
||||
|
||||
// TODO: This should have a structured response from the API
|
||||
type UnknownOllamaKey struct {
|
||||
Key string
|
||||
}
|
||||
|
||||
func (e UnknownOllamaKey) Error() string {
|
||||
return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key))
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user