mirror of
https://github.com/ollama/ollama.git
synced 2025-03-27 02:01:56 +01:00
Replace large-chunk blob downloads with parallel small-chunk verification to solve timeout and performance issues. Registry users experienced progressively slowing download speeds as large-chunk transfers aged, often timing out completely. The previous approach downloaded blobs in a few large chunks but required a separate, single-threaded pass to read the entire blob back from disk for verification after download completion. This change uses the new chunksums API to fetch many smaller chunk+digest pairs, allowing concurrent downloads and immediate verification as each chunk arrives. Chunks are written directly to their final positions, eliminating the entire separate verification pass. The result is more reliable downloads that maintain speed throughout the transfer process and significantly faster overall completion, especially over unstable connections or with large blobs.
300 lines
8.5 KiB
Go
300 lines
8.5 KiB
Go
package registry
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"io"
|
|
"io/fs"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"regexp"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
|
|
"github.com/ollama/ollama/server/internal/cache/blob"
|
|
"github.com/ollama/ollama/server/internal/client/ollama"
|
|
"github.com/ollama/ollama/server/internal/testutil"
|
|
"golang.org/x/tools/txtar"
|
|
|
|
_ "embed"
|
|
)
|
|
|
|
// panicTransport is an http.RoundTripper that panics whenever it is used.
type panicTransport struct{}

// RoundTrip implements http.RoundTripper. It never returns; it always panics.
func (*panicTransport) RoundTrip(*http.Request) (*http.Response, error) {
	panic("unexpected RoundTrip call")
}

// panicOnRoundTrip is an HTTP client whose transport panics on any request.
var panicOnRoundTrip = &http.Client{
	Transport: new(panicTransport),
}
|
|
|
|
// bytesResetter exposes only the two operations needed to inspect captured
// logs: reading the accumulated bytes and clearing them. Keeping the surface
// this small prevents inadvertent use of bytes.Buffer.Read/Write and friends
// when checking logs.
type bytesResetter interface {
	// Reset clears the captured content.
	Reset()

	// Bytes returns the captured content.
	Bytes() []byte
}
|
|
|
|
func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local {
|
|
t.Helper()
|
|
dir := t.TempDir()
|
|
err := os.CopyFS(dir, os.DirFS("testdata/models"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
c, err := blob.Open(dir)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
client := panicOnRoundTrip
|
|
if upstreamRegistry != nil {
|
|
s := httptest.NewTLSServer(upstreamRegistry)
|
|
t.Cleanup(s.Close)
|
|
tr := s.Client().Transport.(*http.Transport).Clone()
|
|
tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
|
var d net.Dialer
|
|
return d.DialContext(ctx, "tcp", s.Listener.Addr().String())
|
|
}
|
|
client = &http.Client{Transport: tr}
|
|
}
|
|
|
|
rc := &ollama.Registry{
|
|
Cache: c,
|
|
HTTPClient: client,
|
|
Mask: "example.com/library/_:latest",
|
|
}
|
|
|
|
l := &Local{
|
|
Client: rc,
|
|
Logger: testutil.Slogger(t),
|
|
}
|
|
return l
|
|
}
|
|
|
|
func (s *Local) send(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
|
|
t.Helper()
|
|
req := httptest.NewRequestWithContext(t.Context(), method, path, strings.NewReader(body))
|
|
return s.sendRequest(t, req)
|
|
}
|
|
|
|
func (s *Local) sendRequest(t *testing.T, req *http.Request) *httptest.ResponseRecorder {
|
|
t.Helper()
|
|
w := httptest.NewRecorder()
|
|
s.ServeHTTP(w, req)
|
|
return w
|
|
}
|
|
|
|
// invalidReader is an io.Reader whose every Read fails with os.ErrInvalid.
type invalidReader struct{}

// Read reads nothing and always reports os.ErrInvalid.
func (*invalidReader) Read([]byte) (int, error) {
	return 0, os.ErrInvalid
}
|
|
|
|
// captureLogs is a helper to capture logs from the server. It returns a
|
|
// shallow copy of the server with a new logger and a bytesResetter for the
|
|
// logs.
|
|
func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
|
|
t.Helper()
|
|
log, logs := testutil.SlogBuffer()
|
|
l := *s // shallow copy
|
|
l.Logger = log
|
|
return &l, logs
|
|
}
|
|
|
|
func TestServerDelete(t *testing.T) {
|
|
check := testutil.Checker(t)
|
|
|
|
s := newTestServer(t, nil)
|
|
|
|
_, err := s.Client.ResolveLocal("smol")
|
|
check(err)
|
|
|
|
got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
|
|
if got.Code != 200 {
|
|
t.Fatalf("Code = %d; want 200", got.Code)
|
|
}
|
|
|
|
_, err = s.Client.ResolveLocal("smol")
|
|
if err == nil {
|
|
t.Fatal("expected smol to have been deleted")
|
|
}
|
|
|
|
got = s.send(t, "DELETE", "/api/delete", `!`)
|
|
checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
|
|
|
|
got = s.send(t, "GET", "/api/delete", `{"model": "smol"}`)
|
|
checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
|
|
|
|
got = s.send(t, "DELETE", "/api/delete", ``)
|
|
checkErrorResponse(t, got, 400, "bad_request", "empty request body")
|
|
|
|
got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`)
|
|
checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
|
|
|
|
got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body
|
|
checkErrorResponse(t, got, 404, "not_found", "not found")
|
|
|
|
s, logs := captureLogs(t, s)
|
|
req := httptest.NewRequestWithContext(t.Context(), "DELETE", "/api/delete", &invalidReader{})
|
|
got = s.sendRequest(t, req)
|
|
checkErrorResponse(t, got, 500, "internal_error", "internal server error")
|
|
ok, err := regexp.Match(`ERROR.*error="invalid argument"`, logs.Bytes())
|
|
check(err)
|
|
if !ok {
|
|
t.Logf("logs:\n%s", logs)
|
|
t.Fatalf("expected log to contain ERROR with invalid argument")
|
|
}
|
|
}
|
|
|
|
//go:embed testdata/registry.txt
|
|
var registryTXT []byte
|
|
|
|
var registryFS = sync.OnceValue(func() fs.FS {
|
|
// Txtar gets hung up on \r\n line endings, so we need to convert them
|
|
// to \n when parsing the txtar on Windows.
|
|
data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
|
|
a := txtar.Parse(data)
|
|
fsys, err := txtar.FS(a)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return fsys
|
|
})
|
|
|
|
func TestServerPull(t *testing.T) {
|
|
modelsHandler := http.FileServerFS(registryFS())
|
|
s := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/v2/library/BOOM/manifests/latest":
|
|
w.WriteHeader(999)
|
|
io.WriteString(w, `{"error": "boom"}`)
|
|
case "/v2/library/unknown/manifests/latest":
|
|
w.WriteHeader(404)
|
|
io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
|
|
default:
|
|
t.Logf("serving blob: %s", r.URL.Path)
|
|
modelsHandler.ServeHTTP(w, r)
|
|
}
|
|
})
|
|
|
|
checkResponse := func(got *httptest.ResponseRecorder, wantlines string) {
|
|
t.Helper()
|
|
|
|
if got.Code != 200 {
|
|
t.Errorf("Code = %d; want 200", got.Code)
|
|
}
|
|
gotlines := got.Body.String()
|
|
t.Logf("got:\n%s", gotlines)
|
|
for want := range strings.Lines(wantlines) {
|
|
want = strings.TrimSpace(want)
|
|
want, unwanted := strings.CutPrefix(want, "!")
|
|
want = strings.TrimSpace(want)
|
|
if !unwanted && !strings.Contains(gotlines, want) {
|
|
t.Errorf("! missing %q in body", want)
|
|
}
|
|
if unwanted && strings.Contains(gotlines, want) {
|
|
t.Errorf("! unexpected %q in body", want)
|
|
}
|
|
}
|
|
}
|
|
|
|
got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
|
|
checkResponse(got, `
|
|
{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
|
|
`)
|
|
|
|
got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
|
|
checkResponse(got, `
|
|
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
|
|
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
|
|
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
|
|
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
|
|
`)
|
|
|
|
got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
|
|
checkResponse(got, `
|
|
{"status":"error: model \"unknown\" not found"}
|
|
`)
|
|
|
|
got = s.send(t, "DELETE", "/api/pull", `{"model": "smol"}`)
|
|
checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
|
|
|
|
got = s.send(t, "POST", "/api/pull", `!`)
|
|
checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
|
|
|
|
got = s.send(t, "POST", "/api/pull", ``)
|
|
checkErrorResponse(t, got, 400, "bad_request", "empty request body")
|
|
|
|
got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
|
|
checkResponse(got, `
|
|
{"status":"error: invalid or missing name: \"\""}
|
|
`)
|
|
|
|
// Non-streaming pulls
|
|
got = s.send(t, "POST", "/api/pull", `{"model": "://", "stream": false}`)
|
|
checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
|
|
got = s.send(t, "POST", "/api/pull", `{"model": "smol", "stream": false}`)
|
|
checkResponse(got, `
|
|
{"status":"success"}
|
|
!digest
|
|
!total
|
|
!completed
|
|
`)
|
|
got = s.send(t, "POST", "/api/pull", `{"model": "unknown", "stream": false}`)
|
|
checkErrorResponse(t, got, 404, "not_found", "model not found")
|
|
}
|
|
|
|
func TestServerUnknownPath(t *testing.T) {
|
|
s := newTestServer(t, nil)
|
|
got := s.send(t, "DELETE", "/api/unknown", `{}`)
|
|
checkErrorResponse(t, got, 404, "not_found", "not found")
|
|
|
|
var fellback bool
|
|
s.Fallback = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fellback = true
|
|
})
|
|
got = s.send(t, "DELETE", "/api/unknown", `{}`)
|
|
if !fellback {
|
|
t.Fatal("expected Fallback to be called")
|
|
}
|
|
if got.Code != 200 {
|
|
t.Fatalf("Code = %d; want 200", got.Code)
|
|
}
|
|
}
|
|
|
|
func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {
|
|
t.Helper()
|
|
|
|
var printedBody bool
|
|
errorf := func(format string, args ...any) {
|
|
t.Helper()
|
|
if !printedBody {
|
|
t.Logf("BODY:\n%s", got.Body.String())
|
|
printedBody = true
|
|
}
|
|
t.Errorf(format, args...)
|
|
}
|
|
|
|
if got.Code != status {
|
|
errorf("Code = %d; want %d", got.Code, status)
|
|
}
|
|
|
|
// unmarshal the error as *ollama.Error (proving *serverError is an *ollama.Error)
|
|
var e *ollama.Error
|
|
if err := json.Unmarshal(got.Body.Bytes(), &e); err != nil {
|
|
errorf("unmarshal error: %v", err)
|
|
t.FailNow()
|
|
}
|
|
if e.Code != code {
|
|
errorf("Code = %q; want %q", e.Code, code)
|
|
}
|
|
if !strings.Contains(e.Message, msg) {
|
|
errorf("Message = %q; want to contain %q", e.Message, msg)
|
|
}
|
|
}
|