ollama/server/images.go

823 lines
20 KiB
Go
Raw Permalink Normal View History

package server
import (
"bytes"
"context"
"crypto/sha256"
2023-08-28 20:50:24 -07:00
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"log/slog"
"net"
"net/http"
2023-08-21 18:38:31 -07:00
"net/url"
"os"
"path/filepath"
2023-08-21 18:24:42 -07:00
"runtime"
2024-05-21 21:30:52 -07:00
"slices"
2024-02-14 11:29:49 -08:00
"strconv"
"strings"
"github.com/ollama/ollama/api"
2024-06-04 11:53:23 -07:00
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/parser"
2024-06-10 14:54:42 -07:00
"github.com/ollama/ollama/template"
2024-04-16 16:22:38 -07:00
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
// errCapabilities is the leading error that CheckCapabilities wraps together
// with the joined per-capability errors declared below.
var errCapabilities = errors.New("does not support")

// Sentinel errors naming the individual capability a model is missing.
var (
	errCapabilityInsert     = errors.New("insert")
	errCapabilityTools      = errors.New("tools")
	errCapabilityCompletion = errors.New("completion")
)
2024-06-17 10:38:55 -07:00
2024-06-11 14:03:42 -07:00
// Capability names a feature that a model may or may not support. Values are
// checked by (*Model).CheckCapabilities.
type Capability string

const (
	// CapabilityCompletion is the capability name "completion".
	CapabilityCompletion = Capability("completion")
	// CapabilityTools is the capability name "tools".
	CapabilityTools = Capability("tools")
	// CapabilityInsert is the capability name "insert".
	CapabilityInsert = Capability("insert")
)
2024-06-11 14:03:42 -07:00
2024-02-14 11:29:49 -08:00
// registryOptions carries the per-request settings used when talking to a
// model registry.
type registryOptions struct {
	// Insecure permits the plain "http" scheme; makeRequest switches the
	// request URL to http when it is set.
	Insecure bool
	// Username and Password are sent as HTTP basic auth when both are set and
	// Token is empty.
	Username string
	Password string
	// Token, when non-empty, is sent as a Bearer Authorization header and
	// takes precedence over Username/Password.
	Token string

	// CheckRedirect is installed as the http.Client's CheckRedirect policy.
	CheckRedirect func(req *http.Request, via []*http.Request) error
}
type Model struct {
2023-11-30 10:30:23 -08:00
Name string `json:"name"`
2023-12-01 11:37:17 -08:00
Config ConfigV2
2023-11-30 10:30:23 -08:00
ShortName string
ModelPath string
2024-01-25 12:12:36 -08:00
ParentModel string
2023-11-30 10:30:23 -08:00
AdapterPaths []string
ProjectorPaths []string
System string
License []string
Digest string
Options map[string]interface{}
2024-06-19 14:14:28 -07:00
Messages []api.Message
2024-06-10 14:54:42 -07:00
Template *template.Template
2024-01-25 12:12:36 -08:00
}
2024-06-17 10:38:55 -07:00
// CheckCapabilities checks if the model has the specified capabilities returning an error describing
// any missing or unknown capabilities
func (m *Model) CheckCapabilities(caps ...Capability) error {
var errs []error
2024-06-11 14:03:42 -07:00
for _, cap := range caps {
switch cap {
case CapabilityCompletion:
r, err := os.Open(m.ModelPath)
2024-06-14 14:57:49 -07:00
if err != nil {
slog.Error("couldn't open model file", "error", err)
continue
}
defer r.Close()
2024-06-14 14:57:49 -07:00
// TODO(mxyng): decode the GGML into model to avoid doing this multiple times
f, _, err := ggml.Decode(r, 0)
2024-06-14 14:57:49 -07:00
if err != nil {
slog.Error("couldn't decode ggml", "error", err)
continue
}
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
2024-06-17 10:38:55 -07:00
errs = append(errs, errCapabilityCompletion)
2024-06-11 14:03:42 -07:00
}
2024-06-20 13:45:47 -07:00
case CapabilityTools:
if !slices.Contains(m.Template.Vars(), "tools") {
errs = append(errs, errCapabilityTools)
}
case CapabilityInsert:
vars := m.Template.Vars()
if !slices.Contains(vars, "suffix") {
errs = append(errs, errCapabilityInsert)
2024-06-20 13:45:47 -07:00
}
2024-06-11 14:03:42 -07:00
default:
slog.Error("unknown capability", "capability", cap)
2024-06-17 10:38:55 -07:00
return fmt.Errorf("unknown capability: %s", cap)
2024-06-11 14:03:42 -07:00
}
}
2024-06-17 10:38:55 -07:00
if err := errors.Join(errs...); err != nil {
return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
2024-06-17 10:38:55 -07:00
}
return nil
}
2024-04-30 10:55:19 -07:00
func (m *Model) String() string {
var modelfile parser.Modelfile
2024-04-30 10:55:19 -07:00
modelfile.Commands = append(modelfile.Commands, parser.Command{
2024-04-30 10:55:19 -07:00
Name: "model",
Args: m.ModelPath,
})
2024-05-08 12:42:48 -07:00
for _, adapter := range m.AdapterPaths {
modelfile.Commands = append(modelfile.Commands, parser.Command{
2024-05-08 12:42:48 -07:00
Name: "adapter",
Args: adapter,
2024-04-30 10:55:19 -07:00
})
}
2024-05-08 12:42:48 -07:00
for _, projector := range m.ProjectorPaths {
modelfile.Commands = append(modelfile.Commands, parser.Command{
2024-05-08 12:42:48 -07:00
Name: "model",
Args: projector,
2024-04-30 10:55:19 -07:00
})
}
2024-06-10 14:54:42 -07:00
if m.Template != nil {
modelfile.Commands = append(modelfile.Commands, parser.Command{
2024-05-08 12:42:48 -07:00
Name: "template",
2024-06-10 14:54:42 -07:00
Args: m.Template.String(),
2024-04-30 10:55:19 -07:00
})
}
2024-05-08 12:42:48 -07:00
if m.System != "" {
modelfile.Commands = append(modelfile.Commands, parser.Command{
2024-05-08 12:42:48 -07:00
Name: "system",
Args: m.System,
2024-04-30 10:55:19 -07:00
})
}
for k, v := range m.Options {
switch v := v.(type) {
case []any:
for _, s := range v {
modelfile.Commands = append(modelfile.Commands, parser.Command{
2024-04-30 10:55:19 -07:00
Name: k,
Args: fmt.Sprintf("%v", s),
})
}
default:
modelfile.Commands = append(modelfile.Commands, parser.Command{
2024-04-30 10:55:19 -07:00
Name: k,
Args: fmt.Sprintf("%v", v),
})
}
}
for _, license := range m.License {
modelfile.Commands = append(modelfile.Commands, parser.Command{
2024-04-30 10:55:19 -07:00
Name: "license",
Args: license,
})
}
for _, msg := range m.Messages {
modelfile.Commands = append(modelfile.Commands, parser.Command{
2024-04-30 10:55:19 -07:00
Name: "message",
2024-07-31 16:52:09 -07:00
Args: fmt.Sprintf("%s: %s", msg.Role, msg.Content),
2024-04-30 10:55:19 -07:00
})
}
2024-04-30 10:55:19 -07:00
return modelfile.String()
}
// ConfigV2 is the JSON document stored in a model's config blob. GetModel
// decodes it into Model.Config.
type ConfigV2 struct {
	ModelFormat   string   `json:"model_format"`
	ModelFamily   string   `json:"model_family"`
	ModelFamilies []string `json:"model_families"`
	ModelType     string   `json:"model_type"`
	FileType      string   `json:"file_type"`

	// required by spec
	Architecture string `json:"architecture"`
	OS           string `json:"os"`
	RootFS       RootFS `json:"rootfs"`
}

// RootFS is the "rootfs" section of ConfigV2: a type string plus the list of
// diff IDs.
type RootFS struct {
	Type    string   `json:"type"`
	DiffIDs []string `json:"diff_ids"`
}
2024-06-10 08:47:13 -07:00
func GetManifest(mp ModelPath) (*Manifest, string, error) {
fp, err := mp.GetManifestPath()
if err != nil {
2023-08-28 20:50:24 -07:00
return nil, "", err
}
2024-08-14 14:37:51 -07:00
f, err := os.Open(fp)
if err != nil {
2024-08-14 14:37:51 -07:00
return nil, "", err
}
2024-08-14 14:37:51 -07:00
defer f.Close()
2024-08-14 14:37:51 -07:00
sha256sum := sha256.New()
2023-08-28 20:50:24 -07:00
2024-08-14 14:37:51 -07:00
var manifest Manifest
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&manifest); err != nil {
2023-08-28 20:50:24 -07:00
return nil, "", err
}
2024-08-14 14:37:51 -07:00
return &manifest, hex.EncodeToString(sha256sum.Sum(nil)), nil
}
func GetModel(name string) (*Model, error) {
mp := ParseModelPath(name)
2023-08-28 20:50:24 -07:00
manifest, digest, err := GetManifest(mp)
if err != nil {
return nil, err
}
model := &Model{
Name: mp.GetFullTagname(),
ShortName: mp.GetShortTagname(),
Digest: digest,
2024-06-10 14:54:42 -07:00
Template: template.DefaultTemplate,
}
if manifest.Config.Digest != "" {
filename, err := GetBlobsPath(manifest.Config.Digest)
if err != nil {
return nil, err
}
2023-12-01 11:37:17 -08:00
configFile, err := os.Open(filename)
if err != nil {
return nil, err
}
defer configFile.Close()
2023-12-01 11:37:17 -08:00
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
return nil, err
}
2023-12-01 11:37:17 -08:00
}
for _, layer := range manifest.Layers {
2023-07-17 22:44:21 -07:00
filename, err := GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}
switch layer.MediaType {
case "application/vnd.ollama.image.model":
model.ModelPath = filename
2024-01-25 12:12:36 -08:00
model.ParentModel = layer.From
2023-08-04 18:56:40 -04:00
case "application/vnd.ollama.image.embed":
// Deprecated in versions > 0.1.2
// TODO: remove this warning in a future version
slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
case "application/vnd.ollama.image.adapter":
model.AdapterPaths = append(model.AdapterPaths, filename)
2023-11-30 10:30:23 -08:00
case "application/vnd.ollama.image.projector":
model.ProjectorPaths = append(model.ProjectorPaths, filename)
2024-06-10 14:54:42 -07:00
case "application/vnd.ollama.image.prompt",
"application/vnd.ollama.image.template":
bts, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
2024-06-10 14:54:42 -07:00
model.Template, err = template.Parse(string(bts))
if err != nil {
return nil, err
}
2024-06-10 14:54:42 -07:00
case "application/vnd.ollama.image.system":
bts, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
2024-06-10 14:54:42 -07:00
model.System = string(bts)
case "application/vnd.ollama.image.params":
2023-07-17 12:08:10 -07:00
params, err := os.Open(filename)
if err != nil {
return nil, err
}
defer params.Close()
// parse model options parameters into a map so that we can see which fields have been specified explicitly
if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
return nil, err
}
2024-01-25 12:12:36 -08:00
case "application/vnd.ollama.image.messages":
msgs, err := os.Open(filename)
if err != nil {
return nil, err
}
defer msgs.Close()
if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
return nil, err
}
2023-09-06 11:04:17 -07:00
case "application/vnd.ollama.image.license":
bts, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
model.License = append(model.License, string(bts))
}
}
return model, nil
}
2024-04-16 16:22:38 -07:00
func CopyModel(src, dst model.Name) error {
if !dst.IsFullyQualified() {
return model.Unqualified(dst)
}
if !src.IsFullyQualified() {
return model.Unqualified(src)
}
2024-04-28 23:47:49 -04:00
if src.Filepath() == dst.Filepath() {
return nil
}
2024-04-16 16:22:38 -07:00
manifests, err := GetManifestPath()
2023-08-21 21:56:56 -07:00
if err != nil {
return err
}
dstpath := filepath.Join(manifests, dst.Filepath())
2024-04-16 16:22:38 -07:00
if err := os.MkdirAll(filepath.Dir(dstpath), 0o755); err != nil {
return err
}
2023-07-24 11:27:28 -04:00
srcpath := filepath.Join(manifests, src.Filepath())
2024-04-16 16:22:38 -07:00
srcfile, err := os.Open(srcpath)
2023-07-24 11:27:28 -04:00
if err != nil {
return err
}
2024-04-16 16:22:38 -07:00
defer srcfile.Close()
2023-07-24 11:27:28 -04:00
2024-04-16 16:22:38 -07:00
dstfile, err := os.Create(dstpath)
2023-07-24 11:27:28 -04:00
if err != nil {
return err
}
2024-04-16 16:22:38 -07:00
defer dstfile.Close()
2023-07-24 11:27:28 -04:00
2024-04-16 16:22:38 -07:00
_, err = io.Copy(dstfile, srcfile)
return err
2023-07-24 11:27:28 -04:00
}
2024-08-14 16:36:07 -07:00
func deleteUnusedLayers(deleteMap map[string]struct{}) error {
// Ignore corrupt manifests to avoid blocking deletion of layers that are freshly orphaned
manifests, err := Manifests(true)
2023-07-20 16:09:23 -07:00
if err != nil {
return err
}
2023-08-30 14:31:12 -04:00
2024-08-14 16:36:07 -07:00
for _, manifest := range manifests {
2023-08-30 14:31:12 -04:00
for _, layer := range manifest.Layers {
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
2023-07-31 15:26:18 -07:00
}
2023-07-20 16:09:23 -07:00
// only delete the files which are still in the deleteMap
2023-11-14 12:30:34 -08:00
for k := range deleteMap {
fp, err := GetBlobsPath(k)
if err != nil {
slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err))
2023-11-14 12:30:34 -08:00
continue
}
2024-05-09 16:35:20 -07:00
if err := os.Remove(fp); err != nil {
slog.Info(fmt.Sprintf("couldn't remove file '%s': %v", fp, err))
continue
2023-07-20 16:09:23 -07:00
}
}
return nil
}
func PruneLayers() error {
2023-11-14 12:30:34 -08:00
deleteMap := make(map[string]struct{})
p, err := GetBlobsPath("")
if err != nil {
return err
}
blobs, err := os.ReadDir(p)
if err != nil {
slog.Info(fmt.Sprintf("couldn't read dir '%s': %v", p, err))
return err
}
for _, blob := range blobs {
name := blob.Name()
name = strings.ReplaceAll(name, "-", ":")
2024-05-09 16:35:20 -07:00
_, err := GetBlobsPath(name)
if err != nil {
if errors.Is(err, ErrInvalidDigestFormat) {
// remove invalid blobs (e.g. partial downloads)
if err := os.Remove(filepath.Join(p, blob.Name())); err != nil {
slog.Error("couldn't remove blob", "blob", blob.Name(), "error", err)
}
}
continue
2023-11-14 14:27:51 -08:00
}
2024-05-09 16:35:20 -07:00
deleteMap[name] = struct{}{}
}
slog.Info(fmt.Sprintf("total blobs: %d", len(deleteMap)))
2024-08-14 16:36:07 -07:00
if err := deleteUnusedLayers(deleteMap); err != nil {
slog.Error(fmt.Sprintf("couldn't remove unused layers: %v", err))
return nil
}
slog.Info(fmt.Sprintf("total unused blobs removed: %d", len(deleteMap)))
return nil
}
2023-09-26 17:28:14 -07:00
// PruneDirectory walks the tree rooted at path depth-first and removes every
// directory that ends up empty, including path itself. Regular files and
// symlinks are left untouched (symlinks are not followed). The first error
// encountered while inspecting, listing, or removing an entry is returned.
func PruneDirectory(path string) error {
	fi, err := os.Lstat(path)
	if err != nil {
		return err
	}

	// Only real directories are pruned.
	if !fi.IsDir() || fi.Mode()&os.ModeSymlink != 0 {
		return nil
	}

	children, err := os.ReadDir(path)
	if err != nil {
		return err
	}

	for i := range children {
		child := filepath.Join(path, children[i].Name())
		if err := PruneDirectory(child); err != nil {
			return err
		}
	}

	// List again: pruning the children may have emptied this directory.
	remaining, err := os.ReadDir(path)
	if err != nil {
		return err
	}

	if len(remaining) == 0 {
		return os.Remove(path)
	}

	return nil
}
2024-02-14 11:29:49 -08:00
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
2023-07-18 18:51:30 -07:00
fn(api.ProgressResponse{Status: "retrieving manifest"})
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
2024-08-01 14:52:15 -07:00
return errors.New("insecure protocol http")
}
2023-08-28 20:50:24 -07:00
manifest, _, err := GetManifest(mp)
if err != nil {
2023-07-18 18:51:30 -07:00
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
return err
}
var layers []Layer
2023-07-31 21:37:40 -04:00
layers = append(layers, manifest.Layers...)
if manifest.Config.Digest != "" {
layers = append(layers, manifest.Config)
}
for _, layer := range layers {
2023-10-09 10:24:27 -07:00
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
return err
}
2023-07-18 18:51:30 -07:00
}
fn(api.ProgressResponse{Status: "pushing manifest"})
2023-08-21 18:38:31 -07:00
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}
2023-08-21 18:24:42 -07:00
headers := make(http.Header)
headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
2023-11-02 13:10:58 -07:00
resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
if err != nil {
return err
}
defer resp.Body.Close()
fn(api.ProgressResponse{Status: "success"})
return nil
}
2024-02-14 11:29:49 -08:00
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
// build deleteMap to prune unused layers
2023-11-14 12:30:34 -08:00
deleteMap := make(map[string]struct{})
2024-08-14 14:37:51 -07:00
manifest, _, err := GetManifest(mp)
if errors.Is(err, os.ErrNotExist) {
// noop
} else if err != nil {
slog.Warn("pulling model with bad existing manifest", "name", name, "error", err)
2024-08-14 14:37:51 -07:00
} else {
for _, l := range manifest.Layers {
deleteMap[l.Digest] = struct{}{}
}
2024-08-14 14:37:51 -07:00
if manifest.Config.Digest != "" {
deleteMap[manifest.Config.Digest] = struct{}{}
}
}
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
2024-08-01 14:52:15 -07:00
return errors.New("insecure protocol http")
2023-08-21 21:56:56 -07:00
}
2023-07-18 18:51:30 -07:00
fn(api.ProgressResponse{Status: "pulling manifest"})
manifest, err = pullModelManifest(ctx, mp, regOpts)
if err != nil {
return fmt.Errorf("pull model manifest: %s", err)
}
var layers []Layer
2023-07-20 20:18:00 +02:00
layers = append(layers, manifest.Layers...)
if manifest.Config.Digest != "" {
layers = append(layers, manifest.Config)
}
skipVerify := make(map[string]bool)
for _, layer := range layers {
cacheHit, err := downloadBlob(ctx, downloadOpts{
mp: mp,
digest: layer.Digest,
regOpts: regOpts,
fn: fn,
})
if err != nil {
return err
}
skipVerify[layer.Digest] = cacheHit
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
2023-07-20 11:44:05 -07:00
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers {
if skipVerify[layer.Digest] {
continue
}
2023-07-20 11:44:05 -07:00
if err := verifyBlob(layer.Digest); err != nil {
2023-07-24 14:53:01 -04:00
if errors.Is(err, errDigestMismatch) {
// something went wrong, delete the blob
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(fp); err != nil {
// log this, but return the original error
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
2023-07-24 14:53:01 -04:00
}
}
2023-07-20 11:44:05 -07:00
return err
}
}
2023-07-18 18:51:30 -07:00
fn(api.ProgressResponse{Status: "writing manifest"})
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}
fp, err := mp.GetManifestPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
return err
}
2023-07-20 20:18:00 +02:00
err = os.WriteFile(fp, manifestJSON, 0o644)
if err != nil {
slog.Info(fmt.Sprintf("couldn't write to %s", fp))
return err
}
2024-08-14 14:37:51 -07:00
if !envconfig.NoPrune() && len(deleteMap) > 0 {
fn(api.ProgressResponse{Status: "removing unused layers"})
2024-08-14 16:36:07 -07:00
if err := deleteUnusedLayers(deleteMap); err != nil {
fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't remove unused layers: %v", err)})
}
}
2023-07-18 18:51:30 -07:00
fn(api.ProgressResponse{Status: "success"})
return nil
}
2024-06-10 08:47:13 -07:00
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
2023-08-21 18:38:31 -07:00
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
2023-08-21 18:24:42 -07:00
headers := make(http.Header)
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
2023-11-02 13:13:32 -07:00
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts)
if err != nil {
return nil, err
}
defer resp.Body.Close()
2024-08-14 14:37:51 -07:00
var m Manifest
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
return nil, err
}
2024-08-14 14:37:51 -07:00
return &m, err
}
// GetSHA256Digest reads r to EOF and returns the SHA-256 digest of its
// contents, formatted as "sha256:<lowercase hex>", along with the number of
// bytes read.
//
// If reading from r fails, the process is terminated via log.Fatal.
func GetSHA256Digest(r io.Reader) (string, int64) {
	h := sha256.New()
	n, err := io.Copy(h, r)
	if err != nil {
		log.Fatal(err)
	}

	return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
}
2024-08-01 14:52:15 -07:00
// errUnauthorized is returned by makeRequestWithRetry when every attempt,
// including the one made after refreshing the token, is answered with
// 401 Unauthorized.
var errUnauthorized = errors.New("unauthorized: access denied")
2024-02-14 11:29:49 -08:00
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
2024-05-21 22:21:04 -07:00
for range 2 {
2024-02-14 11:29:49 -08:00
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
2023-08-17 12:35:29 -07:00
if err != nil {
if !errors.Is(err, context.Canceled) {
slog.Info(fmt.Sprintf("request failed: %v", err))
}
2023-08-17 12:35:29 -07:00
return nil, err
}
switch {
case resp.StatusCode == http.StatusUnauthorized:
resp.Body.Close()
// Handle authentication error with one retry
2024-02-14 11:29:49 -08:00
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
token, err := getAuthorizationToken(ctx, challenge)
2023-08-17 12:35:29 -07:00
if err != nil {
return nil, err
}
regOpts.Token = token
if body != nil {
_, err = body.Seek(0, io.SeekStart)
if err != nil {
return nil, err
}
}
case resp.StatusCode == http.StatusNotFound:
resp.Body.Close()
return nil, os.ErrNotExist
case resp.StatusCode >= http.StatusBadRequest:
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
}
return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
default:
return resp, nil
2023-08-17 12:35:29 -07:00
}
}
return nil, errUnauthorized
2023-08-17 12:35:29 -07:00
}
// testMakeRequestDialContext specifies the dial function for the http client in
// makeRequest. It can be used to resolve hosts in model names to local
// addresses for testing. For example, the model name ("example.com/my/model")
// can be directed to push/pull from "127.0.0.1:1234".
//
// This is not safe to set across goroutines. It should be set in
// the main test goroutine, and not by tests marked to run in parallel with
// t.Parallel().
//
// It should be cleared after use, otherwise it will affect other tests.
//
// Ideally something further up the stack would set this, but the code is not
// structured in a way that makes this easy, so this will have to do for now.
var testMakeRequestDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
2024-02-14 11:29:49 -08:00
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *registryOptions) (*http.Response, error) {
if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
requestURL.Scheme = "http"
}
req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
if err != nil {
return nil, err
}
if headers != nil {
req.Header = headers
}
if regOpts != nil {
if regOpts.Token != "" {
req.Header.Set("Authorization", "Bearer "+regOpts.Token)
} else if regOpts.Username != "" && regOpts.Password != "" {
req.SetBasicAuth(regOpts.Username, regOpts.Password)
}
}
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
2024-02-14 11:29:49 -08:00
if s := req.Header.Get("Content-Length"); s != "" {
contentLength, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return nil, err
}
req.ContentLength = contentLength
}
c := &http.Client{
CheckRedirect: regOpts.CheckRedirect,
2024-02-14 11:29:49 -08:00
}
if testMakeRequestDialContext != nil {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.DialContext = testMakeRequestDialContext
c.Transport = tr
}
return c.Do(req)
2024-02-14 11:29:49 -08:00
}
2023-08-10 11:34:25 -07:00
// getValue extracts the double-quoted value of key from a comma-separated
// parameter list such as `realm="https://example.com/token",service="example.com"`.
//
// It locates the first occurrence of `key=` in header, skips the opening
// quote, and returns everything up to the closing quote. A quote only counts
// as the closing one when it is the last character of header or is followed
// by a comma, which lets a value contain embedded quotes. It returns "" when
// key is absent or when header ends immediately after `key=`.
func getValue(header, key string) string {
	startIdx := strings.Index(header, key+"=")
	if startIdx == -1 {
		return ""
	}

	// Move the index past `key=` and the opening quote.
	startIdx += len(key) + 2
	if startIdx > len(header) {
		// header ends right after `key=`: there is no value, and slicing
		// from here would be out of range.
		return ""
	}

	endIdx := startIdx
	for endIdx < len(header) {
		if header[endIdx] == '"' {
			if endIdx+1 < len(header) && header[endIdx+1] != ',' { // If the next character isn't a comma, continue
				endIdx++
				continue
			}
			break
		}
		endIdx++
	}
	return header[startIdx:endIdx]
}
2024-02-14 11:29:49 -08:00
func parseRegistryChallenge(authStr string) registryChallenge {
2023-08-10 11:34:25 -07:00
authStr = strings.TrimPrefix(authStr, "Bearer ")
2024-02-14 11:29:49 -08:00
return registryChallenge{
2023-08-10 11:34:25 -07:00
Realm: getValue(authStr, "realm"),
Service: getValue(authStr, "service"),
Scope: getValue(authStr, "scope"),
}
}
// errDigestMismatch is wrapped by verifyBlob when a blob's computed SHA-256
// differs from the expected digest; PullModel checks for it with errors.Is to
// decide whether to delete the blob.
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
2023-07-24 14:53:01 -04:00
2023-07-20 11:44:05 -07:00
func verifyBlob(digest string) error {
fp, err := GetBlobsPath(digest)
if err != nil {
return err
}
f, err := os.Open(fp)
if err != nil {
return err
}
defer f.Close()
fileDigest, _ := GetSHA256Digest(f)
if digest != fileDigest {
2023-07-24 14:53:01 -04:00
return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, digest, fileDigest)
2023-07-20 11:44:05 -07:00
}
return nil
}