server: show user feedback when key is anonymous

When an ollama key is not registered with any account on ollama.com this is
not obvious. In the current CLI an error message that the user is not
authorized is displayed. This change brings back previous behavior to show
the user their key and where they should add it. It protects against adding
unexpected keys by checking that the key is available locally.

A follow-up change should add structured errors from the API. This change
just relies on a known error message.
This commit is contained in:
Bruce MacDonald 2024-11-27 15:01:12 -08:00
parent 08a832b482
commit 85822544a9
6 changed files with 247 additions and 13 deletions

View File

@ -20,6 +20,7 @@ import (
"os"
"os/signal"
"path/filepath"
"regexp"
"runtime"
"strconv"
"strings"
@ -35,6 +36,7 @@ import (
"golang.org/x/term"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llama"
@ -42,6 +44,7 @@ import (
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
@ -516,6 +519,48 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generate(cmd, opts)
}
func errFromUnknownKey(unknownKeyErr error) error {
// find SSH public key in the error message
// TODO (brucemacd): the API should return structured errors so that this message parsing isn't needed
sshKeyPattern := `ssh-\w+ [^\s"]+`
re := regexp.MustCompile(sshKeyPattern)
matches := re.FindStringSubmatch(unknownKeyErr.Error())
if len(matches) > 0 {
serverPubKey := matches[0]
localPubKey, err := auth.GetPublicKey()
if err != nil {
return unknownKeyErr
}
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
// try the ollama service public key
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
if err != nil {
return unknownKeyErr
}
localPubKey = strings.TrimSpace(string(svcPubKey))
}
// check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
if serverPubKey != localPubKey {
return unknownKeyErr
}
var msg strings.Builder
msg.WriteString(unknownKeyErr.Error())
msg.WriteString("\n\nYour ollama key is:\n")
msg.WriteString(localPubKey)
msg.WriteString("\nAdd your key at:\n")
msg.WriteString("https://ollama.com/settings/keys")
return errors.New(msg.String())
}
return unknownKeyErr
}
func PushHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
@ -564,6 +609,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
request := api.PushRequest{Name: args[0], Insecure: insecure}
n := model.ParseName(args[0])
isOllamaHost := strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com")
if err := client.Push(cmd.Context(), &request, fn); err != nil {
if spinner != nil {
spinner.Stop()
@ -571,6 +617,11 @@ func PushHandler(cmd *cobra.Command, args []string) error {
if strings.Contains(err.Error(), "access denied") {
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
}
if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
// the user has not added their ollama key to ollama.com
// return an error with a more user-friendly message
return errFromUnknownKey(err)
}
return err
}

View File

@ -15,6 +15,7 @@ import (
"github.com/spf13/cobra"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/errtypes"
)
func TestShowInfo(t *testing.T) {
@ -368,15 +369,13 @@ func TestGetModelfileName(t *testing.T) {
func TestPushHandler(t *testing.T) {
tests := []struct {
name string
modelName string
serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
expectedError string
expectedOutput string
}{
{
name: "successful push",
modelName: "test-model",
modelName: "successful-push",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/push": func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
@ -389,8 +388,8 @@ func TestPushHandler(t *testing.T) {
return
}
if req.Name != "test-model" {
t.Errorf("expected model name 'test-model', got %s", req.Name)
if req.Name != "successful-push" {
t.Errorf("expected model name 'successful-push', got %s", req.Name)
}
// Simulate progress updates
@ -409,11 +408,10 @@ func TestPushHandler(t *testing.T) {
}
},
},
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n",
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/successful-push\n",
},
{
name: "unauthorized push",
modelName: "unauthorized-model",
modelName: "unauthorized-push",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/push": func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
@ -428,10 +426,29 @@ func TestPushHandler(t *testing.T) {
},
expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
},
{
modelName: "unknown-key-err",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/push": func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
uerr := errtypes.UnknownOllamaKey{
Key: "aaa",
}
err := json.NewEncoder(w).Encode(map[string]string{
"error": uerr.Error(),
})
if err != nil {
t.Fatal(err)
}
},
},
expectedError: "unauthorized: unknown ollama key \"aaa\"",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Run(tt.modelName, func(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if handler, ok := tt.serverResponse[r.URL.Path]; ok {
handler(w, r)

View File

@ -23,13 +23,16 @@ import (
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/types/registry"
"github.com/ollama/ollama/version"
)
@ -984,8 +987,6 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
}
var errUnauthorized = errors.New("unauthorized: access denied")
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
for range 2 {
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
@ -1023,13 +1024,33 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
if err != nil {
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
}
var re registry.Errs
if err := json.Unmarshal(responseBody, &re); err == nil && len(re.Errors) > 0 {
if re.HasCode(registry.ErrCodeAnonymous) {
// if the error is due to anonymous access return a custom error
// this error is used by the CLI to direct a user to add their key to an account
pubKey, nestedErr := auth.GetPublicKey()
if nestedErr != nil {
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
return nil, re
}
return nil, errtypes.UnknownOllamaKey{
Key: pubKey,
}
}
return nil, re
}
// Fallback to returning the raw response if parsing fails
return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
default:
return resp, nil
}
}
return nil, errUnauthorized
// should never be reached
return nil, fmt.Errorf("failed to make upload request")
}
// testMakeRequestDialContext specifies the dial function for the http client in

107
server/images_test.go Normal file
View File

@ -0,0 +1,107 @@
package server
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
)
func TestMakeRequestWithRetry(t *testing.T) {
authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{
"token": "test-token",
})
}))
defer authServer.Close()
tests := []struct {
name string
serverHandler http.HandlerFunc
method string
body string
wantErr error
wantStatus int
}{
{
name: "successful request",
serverHandler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("success"))
},
method: http.MethodGet,
wantStatus: http.StatusOK,
},
{
name: "not found error",
serverHandler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
},
method: http.MethodGet,
wantErr: os.ErrNotExist,
},
{
name: "request with body retry",
serverHandler: func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Authorization") == "" {
w.Header().Set("WWW-Authenticate", `Bearer realm="`+authServer.URL+`"`)
w.WriteHeader(http.StatusUnauthorized)
return
}
buf := new(bytes.Buffer)
buf.ReadFrom(r.Body)
if buf.String() != `{"key": "value"}` {
t.Errorf("body not preserved on retry, got %s", buf.String())
}
w.WriteHeader(http.StatusOK)
},
method: http.MethodPost,
body: `{"key": "value"}`,
wantStatus: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(tt.serverHandler)
defer server.Close()
requestURL, _ := url.Parse(server.URL)
var body io.ReadSeeker
if tt.body != "" {
body = strings.NewReader(tt.body)
}
regOpts := &registryOptions{
Insecure: true,
}
resp, err := makeRequestWithRetry(context.Background(), tt.method, requestURL, nil, body, regOpts)
if tt.wantErr != nil {
if !errors.Is(err, tt.wantErr) {
t.Errorf("got error %v, want %v", err, tt.wantErr)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != tt.wantStatus {
t.Errorf("got status %d, want %d", resp.StatusCode, tt.wantStatus)
}
resp.Body.Close()
})
}
}

View File

@ -16,6 +16,6 @@ type UnknownOllamaKey struct {
Key string
}
func (e *UnknownOllamaKey) Error() string {
func (e UnknownOllamaKey) Error() string {
return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key))
}

38
types/registry/error.go Normal file
View File

@ -0,0 +1,38 @@
package registry
import (
"fmt"
"slices"
"strings"
)
const ErrCodeAnonymous = "ANONYMOUS_ACCESS_DENIED"
type Err struct {
Code string `json:"code"`
Message string `json:"message"`
}
// Errs represents the structure of error responses from the registry
// TODO (brucemacd): this struct should be imported from some shared package that is used between the registry and ollama
type Errs struct {
Errors []Err `json:"errors"`
}
// Error implements the error interface for RegistryError
func (e Errs) Error() string {
if len(e.Errors) == 0 {
return "unknown registry error"
}
var msgs []string
for _, err := range e.Errors {
msgs = append(msgs, fmt.Sprintf("%s: %s", err.Code, err.Message))
}
return strings.Join(msgs, "; ")
}
func (e Errs) HasCode(code string) bool {
return slices.ContainsFunc(e.Errors, func(err Err) bool {
return err.Code == code
})
}