cmd: default to client2 and simplify pull progress display

Switch the pull and remove operations to use the client2 registry by default,
removing the need to pass an experimental flag.

Also simplify the progress UI during model pulls. Previously, each layer
displayed its own progress bar, resulting in noisy and repetitive output:

	pulling manifest
	pulling aeda25e63ebd... 100% ▕██████████████████████████ 3.3 GB
	pulling e0a42594d802... 100% ▕██████████████████████████  358 B
	...
	writing manifest
	success

This change replaces that with a single progress bar for the entire pull,
followed by a single "Done." message:

	Downloading gemma3: 100% ▕█████████████████████████████▏ 3.3 GB

This provides a cleaner, more intuitive experience and aligns better with
how users think about a pull: as a single model, rather than as a
collection of layers.

To support older clients that still rely on a fixed-width Digest field,
we format the Digest to be at least 20 characters long. The value includes
padding and a truncated model name to prevent out-of-bounds access in
legacy clients. This is a temporary compatibility hack and can be removed
once all clients have adopted the new API.
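
For illustration, a minimal sketch of the padding scheme (mirroring the
handlePull change further below); the seven-space prefix keeps a legacy
client's resp.Digest[7:19] slice in bounds, and %-13s pads short names out
to at least 20 characters:

	// Seven spaces, then the model name left-aligned in a 13-character
	// field, so the value is always at least 20 characters long.
	digest := fmt.Sprintf("%7s%-13s", "", model) // "smol" -> "       smol         "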

Updates server behavior to handle all combinations of new and old
clients.
Blake Mizerany 2025-03-31 17:05:22 -07:00
parent 2f723ac2d6
commit 2c8f95de19
7 changed files with 246 additions and 261 deletions

View File

@ -814,7 +814,12 @@ func PullHandler(cmd *cobra.Command, args []string) error {
bar, ok := bars[resp.Digest]
if !ok {
bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
name = strings.TrimSpace(name)
if isDigest {
name = name[:min(12, len(name))]
}
bar = progress.NewBar(fmt.Sprintf("Downloading %s:", name), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
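
Read together, the new label logic covers both forms the server may send;
a worked trace of the two cases (sample values illustrative):

	// A real digest: CutPrefix strips "sha256:" and the first 12 hex
	// characters are kept.
	//   "sha256:aeda25e63ebd..." -> "Downloading aeda25e63ebd:"
	// The padded compatibility value: no "sha256:" prefix, so TrimSpace
	// recovers the model name.
	//   "       smol         "  -> "Downloading smol:"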

View File

@ -1,6 +1,5 @@
// Package ollama provides a client for interacting with an Ollama registry
// which pushes and pulls model manifests and layers as defined by the
// [ollama.com/manifest].
// Package ollama implements a client for the Ollama registry API.
// It handles pushing and pulling model manifests and layers as per [ollama.com/manifest].
package ollama
import (
@ -44,33 +43,28 @@ import (
// Errors
var (
// ErrModelNotFound is returned when a manifest is not found in the
// cache or registry.
// ErrModelNotFound indicates a manifest is missing from a cache or registry.
ErrModelNotFound = errors.New("model not found")
// ErrManifestInvalid is returned when a manifest found in a local or
// remote cache is invalid.
// ErrManifestInvalid indicates a manifest is structurally invalid.
ErrManifestInvalid = errors.New("invalid manifest")
// ErrMissingModel is returned when the model part of a name is missing
// or invalid.
// ErrNameInvalid indicates the model name is missing or invalid.
ErrNameInvalid = errors.New("invalid or missing name")
// ErrCached is passed to [Trace.PushUpdate] when a layer already
// exists. It is a non-fatal error and is never returned by [Registry.Push].
// ErrCached signals to [Trace.Update] that a layer exists and was skipped.
// Never returned by [Registry.Push].
ErrCached = errors.New("cached")
// ErrIncomplete is returned by [Registry.Pull] when a model pull was
// incomplete due to one or more layer download failures. Users that
// want specific errors should use [WithTrace].
// ErrIncomplete indicates [Registry.Pull] failed to download one or more layers.
// Use [WithTrace] for detailed error information.
ErrIncomplete = errors.New("incomplete")
)
// Defaults
const (
// DefaultChunkingThreshold is the threshold at which a layer should be
// split up into chunks when downloading.
DefaultChunkingThreshold = 64 << 20
// DefaultChunkingThreshold defines when to download a layer in chunks.
DefaultChunkingThreshold = 64 << 20 // 64MB
)
var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
@ -83,39 +77,34 @@ var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
return blob.Open(dir)
})
// DefaultCache returns the default cache used by the registry. It is
// configured from the OLLAMA_MODELS environment variable, or defaults to
// $HOME/.ollama/models, or, if an error occurs obtaining the home directory,
// it uses the current working directory.
// DefaultCache returns the default model cache.
//
// The cache directory is determined by the OLLAMA_MODELS environment
// variable, if set; otherwise it is $HOME/.ollama/models, or the current
// working directory if $HOME cannot be determined.
func DefaultCache() (*blob.DiskCache, error) {
return defaultCache()
}
// Error is the standard error returned by Ollama APIs. It can represent a
// single or multiple error response.
//
// Single error responses have the following format:
//
// {"code": "optional_code","error":"error message"}
//
// Multiple error responses have the following format:
//
// {"errors": [{"code": "optional_code","message":"error message"}]}
//
// Note, that the error field is used in single error responses, while the
// message field is used in multiple error responses.
//
// In both cases, the code field is optional and may be empty.
// Error represents API error responses in two formats:
// Single: {"code": "optional_code","error":"error message"}
// Multiple: {"errors": [{"code": "optional_code","message":"error message"}]}
// The code field is optional in both formats.
type Error struct {
Status int `json:"-"` // TODO(bmizerany): remove this
status int `json:"-"` // TODO(bmizerany): remove this
Code string `json:"code"`
Message string `json:"message"`
}
// Temporary reports if the error is temporary (e.g. 5xx status code).
func (e *Error) Temporary() bool {
return e.status > 500
}
func (e *Error) Error() string {
var b strings.Builder
b.WriteString("registry responded with status ")
b.WriteString(strconv.Itoa(e.Status))
b.WriteString(strconv.Itoa(e.status))
if e.Code != "" {
b.WriteString(": code ")
b.WriteString(e.Code)
@ -129,7 +118,7 @@ func (e *Error) Error() string {
func (e *Error) LogValue() slog.Value {
return slog.GroupValue(
slog.Int("status", e.Status),
slog.Int("status", e.status),
slog.String("code", e.Code),
slog.String("message", e.Message),
)
@ -172,22 +161,19 @@ var defaultMask = func() names.Name {
return n
}()
// CompleteName returns a fully qualified name by merging the given name with
// the default mask. If the name is already fully qualified, it is returned
// unchanged.
// CompleteName ensures a name is fully qualified by applying DefaultMask if needed.
func CompleteName(name string) string {
return names.Merge(names.Parse(name), defaultMask).String()
}
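
A hedged usage sketch; the mask value is not shown in this hunk, so the
fully qualified form below assumes the default mask is
registry.ollama.ai/library/_:latest:

	CompleteName("gemma3")
	// -> "registry.ollama.ai/library/gemma3:latest" (assumed default mask)
	CompleteName("example.com/org/model:tag")
	// -> returned unchanged: already fully qualified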
// Registry is a client for performing push and pull operations against an
// Ollama registry.
// Registry handles Ollama registry operations.
type Registry struct {
// Cache is the cache used to store models. If nil, [DefaultCache] is
// used.
// Cache is the cache for storing models and their blobs.
// If nil, [DefaultCache] is used.
Cache *blob.DiskCache
// UserAgent is the User-Agent header to send with requests to the
// registry. If empty, the User-Agent is determined by HTTPClient.
// UserAgent sent in HTTP requests.
// If empty, HTTPClient's User-Agent is used.
UserAgent string
// Key is the key used to authenticate with the registry.
@ -195,30 +181,26 @@ type Registry struct {
// Currently, only Ed25519 keys are supported.
Key crypto.PrivateKey
// HTTPClient is the HTTP client used to make requests to the registry.
//
// HTTPClient specifies the [http.Client] for performing registry requests.
// If nil, [http.DefaultClient] is used.
//
// As a quick note: If a Registry function that makes a call to a URL
// with the "https+insecure" scheme, the client will be cloned and the
// transport will be set to skip TLS verification, unless the client's
// Transport done not have a Clone method with the same signature as
// [http.Transport.Clone], which case, the call will fail.
// If a user uses the "https+insecure" scheme in a URL, the client's
// Transport is cloned with InsecureSkipVerify set to true.
// If the Transport does not support cloning, the request will be
// passed, as-is, to HTTPClient.
HTTPClient *http.Client
// MaxStreams is the maximum number of concurrent streams to use when
// pushing or pulling models. If zero, the number of streams is
// determined by [runtime.GOMAXPROCS].
//
// A negative value means no limit.
// MaxStreams limits concurrent transfers.
// If zero, [runtime.GOMAXPROCS] is used.
// If negative, no limit is applied.
MaxStreams int
// ChunkingThreshold is the maximum size of a layer to download in a single
// request. If zero, [DefaultChunkingThreshold] is used.
// ChunkingThreshold defines max layer size for single download.
// If zero, [DefaultChunkingThreshold] is used.
ChunkingThreshold int64
// Mask, if set, is the name used to convert non-fully qualified names
// to fully qualified names. If empty, [DefaultMask] is used.
// Mask completes partial model names.
// If empty, [DefaultMask] is used.
Mask string
}
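
A configuration sketch using only the fields documented above; zero values
fall back to the stated defaults:

	r := &ollama.Registry{
		MaxStreams:        8,         // limit concurrent transfers
		ChunkingThreshold: 128 << 20, // chunk layers larger than 128 MiB
		// Cache, HTTPClient, and Mask left zero: [DefaultCache],
		// [http.DefaultClient], and [DefaultMask] are used.
	}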
@ -304,12 +286,11 @@ func (r *Registry) maxChunkingThreshold() int64 {
}
type PushParams struct {
// From is an optional destination name for the model. If empty, the
// destination name is the same as the source name.
// From specifies an alternative source name. If empty, target name is used.
From string
}
// Push pushes the model with the name in the cache to the remote registry.
// Push uploads a locally cached model to the remote registry.
func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
if p == nil {
p = &PushParams{}
@ -437,13 +418,14 @@ func (r *trackingReader) Read(p []byte) (n int, err error) {
return
}
// Pull pulls the model with the given name from the remote registry into the
// cache.
// Pull copies the named model from the remote registry to the Cache.
// Layers exceeding ChunkingThreshold in size are downloaded in chunks.
// Each layer and chunk is downloaded in its own goroutine.
// The maximum number of download goroutines is limited by MaxStreams.
// If a layer or chunk is already in the Cache, it is skipped.
//
// For layers larger then [Registry.MaxChunkSize], the layer is downloaded in
// chunks of the specified size, and then reassembled and verified. This is
// typically slower than splitting the model up across layers, and is mostly
// utilized for layers of type equal to "application/vnd.ollama.image".
// Clients that need progress updates can use the [WithTrace] option to
// register callbacks for updates.
func (r *Registry) Pull(ctx context.Context, name string) error {
m, err := r.Resolve(ctx, name)
if err != nil {
@ -489,94 +471,93 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
continue
}
var wg sync.WaitGroup
chunked, err := c.Chunked(l.Digest, l.Size)
if err != nil {
t.update(l, 0, err)
continue
}
for cs, err := range r.chunksums(ctx, name, l) {
func() {
var wg sync.WaitGroup
chunked, err := c.Chunked(l.Digest, l.Size)
if err != nil {
// Chunksum stream interrupted. Note in trace
// log and let in-flight downloads complete.
// This will naturally trigger ErrIncomplete
// since received < expected bytes.
t.update(l, 0, err)
break
return
}
defer func() {
// Close the chunked writer when all chunks are
// downloaded.
//
// This is done as a background task in the
// group to allow the next layer to start while
// we wait for the final chunk in this layer to
// complete. It also ensures this is done
// before we exit Pull.
g.Go(func() error {
wg.Wait()
chunked.Close()
return nil
})
}()
cacheKey := fmt.Sprintf(
"v1 pull chunksum %s %s %d-%d",
l.Digest,
cs.Digest,
cs.Chunk.Start,
cs.Chunk.End,
)
cacheKeyDigest := blob.DigestFromBytes(cacheKey)
_, err := c.Get(cacheKeyDigest)
if err == nil {
received.Add(cs.Chunk.Size())
t.update(l, cs.Chunk.Size(), ErrCached)
continue
}
for cs, err := range r.chunksums(ctx, name, l) {
if err != nil {
// Chunksum stream interrupted. Note in trace
// log and let in-flight downloads complete.
// This will naturally trigger ErrIncomplete
// since received < expected bytes.
t.update(l, 0, err)
break
}
wg.Add(1)
g.Go(func() (err error) {
defer func() {
if err == nil {
// Ignore cache key write errors for now. We've already
// reported to trace that the chunk is complete.
//
// Ideally, we should only report completion to trace
// after successful cache commit. This current approach
// works but could trigger unnecessary redownloads if
// the checkpoint key is missing on next pull.
//
// Not incorrect, just suboptimal - fix this in a
// future update.
_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
cacheKey := fmt.Sprintf(
"v1 pull chunksum %s %s %d-%d",
l.Digest,
cs.Digest,
cs.Chunk.Start,
cs.Chunk.End,
)
cacheKeyDigest := blob.DigestFromBytes(cacheKey)
_, err := c.Get(cacheKeyDigest)
if err == nil {
received.Add(cs.Chunk.Size())
t.update(l, cs.Chunk.Size(), ErrCached)
continue
}
received.Add(cs.Chunk.Size())
} else {
t.update(l, 0, err)
wg.Add(1)
g.Go(func() (err error) {
defer func() {
if err == nil {
// Ignore cache key write errors for now. We've already
// reported to trace that the chunk is complete.
//
// Ideally, we should only report completion to trace
// after successful cache commit. This current approach
// works but could trigger unnecessary redownloads if
// the checkpoint key is missing on next pull.
//
// Not incorrect, just suboptimal - fix this in a
// future update.
_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
received.Add(cs.Chunk.Size())
} else {
t.update(l, 0, err)
}
wg.Done()
}()
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
if err != nil {
return err
}
wg.Done()
}()
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
body := &trackingReader{l: l, r: res.Body, update: t.update}
return chunked.Put(cs.Chunk, cs.Digest, body)
})
}
// Close writer immediately after downloads finish, not at Pull
// exit. Using defer would keep file descriptors open until all
// layers complete, potentially exhausting system limits with
// many layers.
//
// The WaitGroup tracks when all chunks finish downloading,
// allowing precise writer closure in a background goroutine.
// Each layer briefly uses one extra goroutine while at most
// maxStreams()-1 chunks download in parallel.
//
// This caps file descriptors at maxStreams() instead of
// growing with layer count.
g.Go(func() error {
wg.Wait()
chunked.Close()
return nil
})
body := &trackingReader{l: l, r: res.Body, update: t.update}
return chunked.Put(cs.Chunk, cs.Digest, body)
})
}
}()
}
if err := g.Wait(); err != nil {
return err
@ -592,8 +573,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
return c.Link(m.Name, md)
}
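
Putting the pieces together, a hedged sketch of a pull with the error
semantics described above (t is a *Trace as defined in trace.go):

	ctx := ollama.WithTrace(context.Background(), t)
	if err := r.Pull(ctx, "gemma3"); err != nil {
		if errors.Is(err, ollama.ErrIncomplete) {
			// One or more layers failed to download; per-layer
			// errors were already delivered to the trace.
		}
		log.Fatal(err)
	}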
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
// before attempting to unlink the model.
// Unlink removes the named model from Cache.
func (r *Registry) Unlink(name string) (ok bool, _ error) {
n, err := r.parseName(name)
if err != nil {
@ -606,14 +586,12 @@ func (r *Registry) Unlink(name string) (ok bool, _ error) {
return c.Unlink(n.String())
}
// Manifest represents a [ollama.com/manifest].
// Manifest represents a model manifest per [ollama.com/manifest].
type Manifest struct {
Name string `json:"-"` // the canonical name of the model
Data []byte `json:"-"` // the raw data of the manifest
Name string `json:"-"` // Canonical model name
Data []byte `json:"-"` // Raw manifest data
Layers []*Layer `json:"layers"`
// For legacy reasons, we still have to download the config layer.
Config *Layer `json:"config"`
Config *Layer `json:"config"` // Legacy requirement
}
// Layer returns the layer with the given
@ -691,14 +669,14 @@ func unmarshalManifest(n names.Name, data []byte) (*Manifest, error) {
return &m, nil
}
// Layer is a layer in a model.
// Layer represents a model component with its metadata.
type Layer struct {
Digest blob.Digest `json:"digest"`
MediaType string `json:"mediaType"`
Size int64 `json:"size"`
Digest blob.Digest `json:"digest"` // Content hash
MediaType string `json:"mediaType"` // Content type
Size int64 `json:"size"` // Size in bytes
}
// ResolveLocal resolves a name to a Manifest in the local cache.
// ResolveLocal resolves name to a Manifest in the Cache.
func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
_, n, d, err := r.parseNameExtended(name)
if err != nil {
@ -973,7 +951,7 @@ func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error)
return nil, ErrModelNotFound
}
re.Status = res.StatusCode
re.status = res.StatusCode
return nil, &re
}
return res, nil
@ -1069,17 +1047,18 @@ var supportedSchemes = []string{
var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", "))
// parseNameExtended parses and validates an extended name, returning the scheme, name,
// and digest.
// parseNameExtended parses and validates an extended name, returning the
// scheme, name, and digest.
//
// If the scheme is empty, scheme will be "https". If an unsupported scheme is
// given, [ErrNameInvalid] wrapped with a display friendly message is returned.
// If the scheme is empty, the returned scheme will be "https".
// If an unsupported scheme is given, [ErrNameInvalid] wrapped with a display
// friendly message is returned.
//
// If the digest is invalid, [ErrNameInvalid] wrapped with a display friendly
// message is returned.
//
// If the name is not, once merged with the mask, fully qualified,
// [ErrNameInvalid] wrapped with a display friendly message is returned.
// If the name after masking is not fully qualified, [ErrNameInvalid] wrapped
// with a display friendly message is returned.
func (r *Registry) parseNameExtended(s string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
scheme, name, digest := splitExtended(s)
scheme = cmp.Or(scheme, "https")
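// Illustrative extended-name forms (the examples are assumptions; the
// exact grammar is defined by splitExtended):
//   "gemma3"                      scheme defaults to "https", mask applied
//   "https+insecure://host/m:tag" TLS verification is skipped
//   "org/model@sha256:<hex>"      name pinned to a digest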

View File

@ -154,7 +154,7 @@ func okHandler(w http.ResponseWriter, r *http.Request) {
func checkErrCode(t *testing.T, err error, status int, code string) {
t.Helper()
var e *Error
if !errors.As(err, &e) || e.Status != status || e.Code != code {
if !errors.As(err, &e) || e.status != status || e.Code != code {
t.Errorf("err = %v; want %v %v", err, status, code)
}
}

View File

@ -4,25 +4,17 @@ import (
"context"
)
// Trace is a set of functions that are called to report progress during blob
// downloads and uploads.
//
// Use [WithTrace] to attach a Trace to a context for use with [Registry.Push]
// and [Registry.Pull].
// Trace reports progress during model transfers.
// Attach using [WithTrace] for use with [Registry.Push] and [Registry.Pull].
type Trace struct {
// Update is called during [Registry.Push] and [Registry.Pull] to
// report the progress of blob uploads and downloads.
// Update reports transfer progress with different states:
// When n=0 and err=nil: transfer started
// When err=[ErrCached]: layer already exists, skipped
// When err!=nil: transfer failed
// When n=l.Size and err=nil: transfer completed
// Otherwise: n bytes transferred so far
//
// The n argument is the number of bytes transferred so far, and err is
// any error that has occurred. If n == 0, and err is nil, the download
// or upload has just started. If err is [ErrCached], the download or
// upload has been skipped because the blob is already present in the
// local cache or remote registry, respectively. Otherwise, if err is
// non-nil, the download or upload has failed. When l.Size == n, and
// err is nil, the download or upload has completed.
//
// A function assigned must be safe for concurrent use. The function is
// called synchronously and so should not block or take long to run.
// Must be safe for concurrent use and non-blocking.
Update func(_ *Layer, n int64, _ error)
}
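
A sketch of an Update callback that distinguishes the states enumerated
above; it only logs, so it stays non-blocking:

	t := &ollama.Trace{
		Update: func(l *ollama.Layer, n int64, err error) {
			switch {
			case errors.Is(err, ollama.ErrCached):
				log.Printf("%s already present, skipped", l.Digest)
			case err != nil:
				log.Printf("%s failed: %v", l.Digest, err)
			case n == 0:
				log.Printf("%s started (%d bytes)", l.Digest, l.Size)
			case n == l.Size:
				log.Printf("%s completed", l.Digest)
			default:
				log.Printf("%s %d/%d bytes", l.Digest, n, l.Size)
			}
		},
	}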
@ -34,18 +26,32 @@ func (t *Trace) update(l *Layer, n int64, err error) {
type traceKey struct{}
// WithTrace returns a context derived from ctx that uses t to report trace
// events.
// WithTrace adds a trace to the context for transfer progress reporting.
func WithTrace(ctx context.Context, t *Trace) context.Context {
return context.WithValue(ctx, traceKey{}, t)
old := traceFromContext(ctx)
if old == t {
// No change, return the original context. This also prevents
// infinite recursion below, if the caller passes the same
// Trace.
return ctx
}
// Create a new Trace that wraps the old one, if any. If we used the
// same pointer t, we end up with a recursive structure.
composed := &Trace{
Update: func(l *Layer, n int64, err error) {
if old != nil {
old.update(l, n, err)
}
t.update(l, n, err)
},
}
return context.WithValue(ctx, traceKey{}, composed)
}
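
Usage sketch of the composition (metricsTrace and progressTrace are
hypothetical names):

	ctx = ollama.WithTrace(ctx, metricsTrace)  // attached first
	ctx = ollama.WithTrace(ctx, progressTrace) // composed on top
	// During a pull, metricsTrace.Update runs before
	// progressTrace.Update for every event.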
var emptyTrace = &Trace{}
// traceFromContext returns the Trace associated with ctx, or an empty Trace if
// none is found.
//
// It never returns nil.
// traceFromContext extracts the Trace from ctx or returns an empty non-nil Trace.
func traceFromContext(ctx context.Context) *Trace {
t, _ := ctx.Value(traceKey{}).(*Trace)
if t == nil {

View File

@ -9,13 +9,13 @@ import (
"fmt"
"io"
"log/slog"
"maps"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/internal/backoff"
)
// Local implements an http.Handler for handling local Ollama API model
@ -241,10 +241,10 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
}
type progressUpdateJSON struct {
Status string `json:"status,omitempty,omitzero"`
Digest blob.Digest `json:"digest,omitempty,omitzero"`
Total int64 `json:"total,omitempty,omitzero"`
Completed int64 `json:"completed,omitempty,omitzero"`
Status string `json:"status,omitempty,omitzero"`
Digest string `json:"digest,omitempty,omitzero"`
Total int64 `json:"total,omitempty,omitzero"`
Completed int64 `json:"completed,omitempty,omitzero"`
}
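
With the string Digest and the compatibility padding, the stream a pull
emits looks like the fixture in the updated test below (one JSON object
per line; spacing illustrative):

	{"digest":"       smol         ","total":8}
	{"digest":"       smol         ","total":8,"completed":8}
	{"status":"success"}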
func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
@ -265,39 +265,27 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
}
return err
}
return enc.Encode(progressUpdateJSON{Status: "success"})
enc.Encode(progressUpdateJSON{Status: "success"})
return nil
}
maybeFlush := func() {
var total, completed atomic.Int64
flushProgress := func() {
enc.Encode(progressUpdateJSON{
// This is a hack to maintain support for older clients,
// as per the requirements given for the task of bringing
// in our new client. In the future, once all clients are
// up-to-date, we can remove this and change our API.
Digest: fmt.Sprintf("%7s%-13s", "", p.model()),
Total: total.Load(),
Completed: completed.Load(),
})
fl, _ := w.(http.Flusher)
if fl != nil {
fl.Flush()
}
}
defer maybeFlush()
var mu sync.Mutex
progress := make(map[*ollama.Layer]int64)
progressCopy := make(map[*ollama.Layer]int64, len(progress))
flushProgress := func() {
defer maybeFlush()
// TODO(bmizerany): Flushing every layer in one update doesn't
// scale well. We could flush only the modified layers or track
// the full download. Needs further consideration, though it's
// fine for now.
mu.Lock()
maps.Copy(progressCopy, progress)
mu.Unlock()
for l, n := range progressCopy {
enc.Encode(progressUpdateJSON{
Digest: l.Digest,
Total: l.Size,
Completed: n,
})
}
}
defer flushProgress()
t := time.NewTicker(1000 * time.Hour) // "unstarted" timer
@ -307,7 +295,13 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
})
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
if n > 0 {
if err != nil && !errors.Is(err, ollama.ErrCached) {
s.Logger.Error("pulling", "model", p.model(), "error", err)
return
}
if n == 0 {
total.Add(l.Size)
} else {
// Block flushing progress updates until every
// layer is accounted for. Clients depend on a
// complete model size to calculate progress
@ -315,16 +309,27 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
// progress indicators would erratically jump
// as new layers are registered.
start()
completed.Add(n)
}
mu.Lock()
progress[l] += n
mu.Unlock()
},
})
done := make(chan error, 1)
go func() {
done <- s.Client.Pull(ctx, p.model())
go func() (err error) {
defer func() { done <- err }()
for _, err := range backoff.Loop(ctx, 3*time.Second) {
if err != nil {
return err
}
err := s.Client.Pull(ctx, p.model())
var oe *ollama.Error
if errors.As(err, &oe) && oe.Temporary() {
// already logged in trace
continue
}
return err
}
return nil
}()
for {
@ -341,6 +346,7 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
status = fmt.Sprintf("error: %v", err)
}
enc.Encode(progressUpdateJSON{Status: status})
return nil
}
return nil
}

View File

@ -78,7 +78,12 @@ func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local {
func (s *Local) send(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
t.Helper()
req := httptest.NewRequestWithContext(t.Context(), method, path, strings.NewReader(body))
ctx := ollama.WithTrace(t.Context(), &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
t.Logf("update: %s %d %v", l.Digest, n, err)
},
})
req := httptest.NewRequestWithContext(ctx, method, path, strings.NewReader(body))
return s.sendRequest(t, req)
}
@ -184,7 +189,6 @@ func TestServerPull(t *testing.T) {
checkResponse := func(got *httptest.ResponseRecorder, wantlines string) {
t.Helper()
if got.Code != 200 {
t.Errorf("Code = %d; want 200", got.Code)
}
@ -203,17 +207,11 @@ func TestServerPull(t *testing.T) {
}
}
got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
got := s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
checkResponse(got, `
{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
`)
got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
checkResponse(got, `
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
{"digest":" smol ","total":8}
{"digest":" smol ","total":8,"completed":8}
{"digest":" smol ","total":8,"completed":8}
`)
got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)

View File

@ -42,12 +42,6 @@ import (
"github.com/ollama/ollama/version"
)
func experimentEnabled(name string) bool {
return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
}
var useClient2 = experimentEnabled("client2")
var mode string = gin.DebugMode
type Server struct {
@ -1275,12 +1269,9 @@ func Serve(ln net.Listener) error {
s := &Server{addr: ln.Addr()}
var rc *ollama.Registry
if useClient2 {
var err error
rc, err = ollama.DefaultRegistry()
if err != nil {
return err
}
rc, err = ollama.DefaultRegistry()
if err != nil {
return err
}
h, err := s.GenerateRoutes(rc)