From 86f3c1c3b94fe26c5fec6322f1ee4ab6c9ab2b47 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Thu, 13 Jul 2023 16:59:46 -0700 Subject: [PATCH] basic distribution w/ push/pull --- api/client.go | 17 +- api/types.go | 47 +++- cmd/cmd.go | 20 +- server/images.go | 678 +++++++++++++++++++++++++++++++++++++++++++++++ server/models.go | 5 + server/routes.go | 131 +++++---- 6 files changed, 818 insertions(+), 80 deletions(-) create mode 100644 server/images.go diff --git a/api/client.go b/api/client.go index 29ab26983..b58e53a96 100644 --- a/api/client.go +++ b/api/client.go @@ -107,12 +107,15 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate type PullProgressFunc func(PullProgress) error func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error { - return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error { - var resp PullProgress - if err := json.Unmarshal(bts, &resp); err != nil { - return err - } + /* + return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error { + var resp PullProgress + if err := json.Unmarshal(bts, &resp); err != nil { + return err + } - return fn(resp) - }) + return fn(resp) + }) + */ + return nil } diff --git a/api/types.go b/api/types.go index 86d116f21..af6e0e411 100644 --- a/api/types.go +++ b/api/types.go @@ -7,16 +7,6 @@ import ( "time" ) -type PullRequest struct { - Model string `json:"model"` -} - -type PullProgress struct { - Total int64 `json:"total"` - Completed int64 `json:"completed"` - Percent float64 `json:"percent"` -} - type GenerateRequest struct { Model string `json:"model"` Prompt string `json:"prompt"` @@ -25,6 +15,43 @@ type GenerateRequest struct { Options `json:"options"` } +type CreateRequest struct { + Name string `json:"name"` + Path string `json:"path"` +} + +type CreateProgress struct { + Status string `json:"status"` +} + +type PullRequest struct { + Name string `json:"name"` + Username string `json:"username"` + Password string `json:"password"` +} + +type PullProgress struct { + Status string `json:"status"` + Digest string `json:"digest,omitempty"` + Total int `json:"total,omitempty"` + Completed int `json:"completed,omitempty"` + Percent float64 `json:"percent,omitempty"` +} + +type PushRequest struct { + Name string `json:"name"` + Username string `json:"username"` + Password string `json:"password"` +} + +type PushProgress struct { + Status string `json:"status"` + Digest string `json:"digest,omitempty"` + Total int `json:"total,omitempty"` + Completed int `json:"completed,omitempty"` + Percent float64 `json:"percent,omitempty"` +} + type GenerateResponse struct { Model string `json:"model"` CreatedAt time.Time `json:"created_at"` diff --git a/cmd/cmd.go b/cmd/cmd.go index 18a90b9a8..a761dd04e 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -52,24 +52,8 @@ func RunRun(cmd *cobra.Command, args []string) error { } func pull(model string) error { - client := api.NewClient() - var bar *progressbar.ProgressBar - return client.Pull( - context.Background(), - &api.PullRequest{Model: model}, - func(progress api.PullProgress) error { - if bar == nil { - if progress.Percent >= 100 { - // already downloaded - return nil - } - - bar = progressbar.DefaultBytes(progress.Total) - } - - return bar.Set64(progress.Completed) - }, - ) + // TODO add this back + return nil } func RunGenerate(cmd *cobra.Command, args []string) error { diff --git a/server/images.go b/server/images.go new file mode 100644 index 000000000..eca185617 --- /dev/null +++ b/server/images.go @@ -0,0 +1,678 @@ +package server + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "log" + "net/http" + "os" + "path" + "strings" + + "github.com/jmorganca/ollama/api" + "github.com/jmorganca/ollama/parser" +) + +var DefaultRegistry string = "http://localhost:6000" + +type ManifestV2 struct { + SchemaVersion int `json:"schemaVersion"` + MediaType string `json:"mediaType"` + Config Layer `json:"config"` + Layers []*Layer `json:"layers"` +} + +type Layer struct { + MediaType string `json:"mediaType"` + Digest string `json:"digest"` + Size int `json:"size"` +} + +type LayerWithBuffer struct { + Layer + + Buffer *bytes.Buffer +} + +type ConfigV2 struct { + Architecture string `json:"architecture"` + OS string `json:"os"` + RootFS RootFS `json:"rootfs"` +} + +type RootFS struct { + Type string `json:"type"` + DiffIDs []string `json:"diff_ids"` +} + +func GetManifest(name string) (*ManifestV2, error) { + home, err := os.UserHomeDir() + if err != nil { + return nil, err + } + + filepath := path.Join(home, ".ollama/models/manifests", name) + _, err = os.Stat(filepath) + if os.IsNotExist(err) { + return nil, fmt.Errorf("couldn't find model '%s'", name) + } + + var manifest *ManifestV2 + + f, err := os.Open(filepath) + if err != nil { + return nil, fmt.Errorf("couldn't open file '%s'", filepath) + } + + decoder := json.NewDecoder(f) + err = decoder.Decode(&manifest) + if err != nil { + return nil, err + } + + return manifest, nil +} + +func GetModel(name string) (*Model, error) { + home, err := os.UserHomeDir() + if err != nil { + return nil, err + } + + manifest, err := GetManifest(name) + if err != nil { + return nil, err + } + + model := &Model{ + Name: name, + } + + for _, layer := range manifest.Layers { + filename := path.Join(home, ".ollama/models/blobs", layer.Digest) + switch layer.MediaType { + case "application/vnd.ollama.image.model": + model.ModelPath = filename + case "application/vnd.ollama.image.prompt": + f, err := os.Open(filename) + if err != nil { + return nil, err + } + data, err := ioutil.ReadAll(f) + if err != nil { + return nil, err + } + model.Prompt = string(data) + case "application/vnd.ollama.image.params": + /* + f, err = os.Open(filename) + if err != nil { + return nil, err + } + */ + + var opts api.Options + /* + decoder = json.NewDecoder(f) + err = decoder.Decode(&opts) + if err != nil { + return nil, err + } + */ + model.Options = opts + } + } + + return model, nil +} + +func CreateModel(name string, mf io.Reader, fn func(status string)) error { + fn("parsing modelfile") + commands, err := parser.Parse(mf) + if err != nil { + return err + } + + var layers []*LayerWithBuffer + var param map[string]string + param = make(map[string]string) + + for _, c := range commands { + log.Printf("[%s] - %s\n", c.Name, c.Arg) + switch c.Name { + case "model": + fn("creating model layer") + file, err := os.Open(c.Arg) + if err != nil { + return fmt.Errorf("failed to open file: %v", err) + } + defer file.Close() + + l, err := CreateLayer(file) + if err != nil { + return fmt.Errorf("failed to create layer: %v", err) + } + l.MediaType = "application/vnd.ollama.image.model" + layers = append(layers, l) + case "prompt": + fn("creating prompt layer") + prompt := strings.NewReader(c.Arg) + l, err := CreateLayer(prompt) + if err != nil { + return fmt.Errorf("failed to create layer: %v", err) + } + l.MediaType = "application/vnd.ollama.image.prompt" + layers = append(layers, l) + default: + param[c.Name] = c.Arg + } + } + + // Create a single layer for the parameters + fn("creating parameter layer") + paramData, err := paramsToReader(param) + if err != nil { + return fmt.Errorf("couldn't create params json: %v", err) + } + l, err := CreateLayer(paramData) + if err != nil { + return fmt.Errorf("failed to create layer: %v", err) + } + l.MediaType = "application/vnd.ollama.image.params" + layers = append(layers, l) + + digests, err := getLayerDigests(layers) + if err != nil { + return err + } + + // Create a layer for the config object + fn("creating config layer") + cfg, err := createConfigLayer(digests) + if err != nil { + return err + } + layers = append(layers, cfg) + + home, err := os.UserHomeDir() + if err != nil { + return err + } + + var manifestLayers []*Layer + + // Write each of the layers to disk + for _, layer := range layers { + filepath := path.Join(home, ".ollama/models/blobs", layer.Digest) + + // TODO add a force flag to always write out the layers + + _, err = os.Stat(filepath) + if os.IsNotExist(err) { + fn(fmt.Sprintf("writing layer %s", layer.Digest)) + out, err := os.Create(filepath) + if err != nil { + log.Printf("couldn't create %s", filepath) + return err + } + defer out.Close() + + _, err = io.Copy(out, layer.Buffer) + if err != nil { + return err + } + } else { + fn(fmt.Sprintf("using already created layer %s", layer.Digest)) + } + + if layer.MediaType == "application/vnd.docker.container.image.v1+json" { + continue + } + + manifestLayer := &Layer{ + MediaType: layer.MediaType, + Size: layer.Size, + Digest: layer.Digest, + } + + manifestLayers = append(manifestLayers, manifestLayer) + } + + // Create the manifest + fn("writing manifest") + manifest := ManifestV2{ + SchemaVersion: 2, + MediaType: "application/vnd.docker.distribution.manifest.v2+json", + Config: Layer{ + MediaType: cfg.MediaType, + Size: cfg.Size, + Digest: cfg.Digest, + }, + Layers: manifestLayers, + } + + manifestJSON, err := json.Marshal(manifest) + if err != nil { + return err + } + + filepath := path.Join(home, ".ollama/models/manifests", name) + err = os.WriteFile(filepath, manifestJSON, 0644) + if err != nil { + log.Printf("couldn't write to %s", filepath) + return err + } + + fn("success") + return nil +} + +func paramsToReader(m map[string]string) (io.Reader, error) { + data, err := json.MarshalIndent(m, "", " ") + if err != nil { + return nil, err + } + + return strings.NewReader(string(data)), nil +} + +func getLayerDigests(layers []*LayerWithBuffer) ([]string, error) { + var digests []string + for _, l := range layers { + if l.Digest == "" { + return nil, fmt.Errorf("layer is missing a digest!") + } + digests = append(digests, l.Digest) + } + return digests, nil +} + +// CreateLayer creates a Layer object from a given file +func CreateLayer(f io.Reader) (*LayerWithBuffer, error) { + buf := new(bytes.Buffer) + _, err := io.Copy(buf, f) + if err != nil { + return nil, err + } + + digest, size := GetSHA256Digest(buf) + + layer := &LayerWithBuffer{ + Layer: Layer{ + MediaType: "application/vnd.docker.image.rootfs.diff.tar", + Digest: digest, + Size: size, + }, + Buffer: buf, + } + + return layer, nil +} + +func PushModel(name, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error { + fn("retrieving manifest", "", 0, 0, 0) + manifest, err := GetManifest(name) + if err != nil { + fn("couldn't retrieve manifest", "", 0, 0, 0) + return err + } + + var repoName string + var tag string + + comps := strings.Split(name, ":") + switch { + case len(comps) < 1 || len(comps) > 2: + return fmt.Errorf("repository name was invalid") + case len(comps) == 1: + repoName = comps[0] + tag = "latest" + case len(comps) == 2: + repoName = comps[0] + tag = comps[1] + } + + var layers []*Layer + var total int + var completed int + for _, layer := range manifest.Layers { + layers = append(layers, layer) + total += layer.Size + } + layers = append(layers, &manifest.Config) + total += manifest.Config.Size + + for _, layer := range layers { + exists, err := checkBlobExistence(DefaultRegistry, repoName, layer.Digest, username, password) + if err != nil { + return err + } + + if exists { + completed += layer.Size + fn("using existing layer", layer.Digest, total, completed, float64(completed)/float64(total)) + continue + } + + fn("starting upload", layer.Digest, total, completed, float64(completed)/float64(total)) + + location, err := startUpload(DefaultRegistry, repoName, username, password) + if err != nil { + log.Printf("couldn't start upload: %v", err) + return err + } + + err = uploadBlob(location, layer, username, password) + if err != nil { + log.Printf("error uploading blob: %v", err) + return err + } + completed += layer.Size + fn("upload complete", layer.Digest, total, completed, float64(completed)/float64(total)) + } + + fn("pushing manifest", "", total, completed, float64(completed/total)) + url := fmt.Sprintf("%s/v2/%s/manifests/%s", DefaultRegistry, repoName, tag) + headers := map[string]string{ + "Content-Type": "application/vnd.docker.distribution.manifest.v2+json", + } + + manifestJSON, err := json.Marshal(manifest) + if err != nil { + return err + } + + resp, err := makeRequest("PUT", url, headers, bytes.NewReader(manifestJSON), username, password) + defer resp.Body.Close() + + // Check for success: For a successful upload, the Docker registry will respond with a 201 Created + if resp.StatusCode != http.StatusCreated { + body, _ := ioutil.ReadAll(resp.Body) + return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body)) + } + + fn("success", "", total, completed, 1.0) + + return nil +} + +func PullModel(name, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error { + var repoName string + var tag string + + comps := strings.Split(name, ":") + switch { + case len(comps) < 1 || len(comps) > 2: + return fmt.Errorf("repository name was invalid") + case len(comps) == 1: + repoName = comps[0] + tag = "latest" + case len(comps) == 2: + repoName = comps[0] + tag = comps[1] + } + + fn("pulling manifest", "", 0, 0, 0) + + manifest, err := pullModelManifest(DefaultRegistry, repoName, tag, username, password) + if err != nil { + fmt.Errorf("Error: %q", err) + return err + } + + var layers []*Layer + var total int + var completed int + for _, layer := range manifest.Layers { + layers = append(layers, layer) + total += layer.Size + } + layers = append(layers, &manifest.Config) + total += manifest.Config.Size + + for _, layer := range layers { + fn("starting download", layer.Digest, total, completed, float64(completed)/float64(total)) + if err := downloadBlob(DefaultRegistry, repoName, layer.Digest, username, password); err != nil { + return err + } + completed += layer.Size + fn("download complete", layer.Digest, total, completed, float64(completed)/float64(total)) + } + + fn("writing manifest", "", total, completed, 1.0) + + home, err := os.UserHomeDir() + if err != nil { + return err + } + + manifestJSON, err := json.Marshal(manifest) + if err != nil { + return err + } + + filepath := path.Join(home, ".ollama/models/manifests", name) + err = os.WriteFile(filepath, manifestJSON, 0644) + if err != nil { + log.Printf("couldn't write to %s", filepath) + return err + } + + fn("success", "", total, completed, 1.0) + + return nil +} + +func pullModelManifest(registryURL, repoName, tag, username, password string) (*ManifestV2, error) { + url := fmt.Sprintf("%s/v2/%s/manifests/%s", registryURL, repoName, tag) + headers := map[string]string{ + "Accept": "application/vnd.docker.distribution.manifest.v2+json", + } + + resp, err := makeRequest("GET", url, headers, nil, username, password) + defer resp.Body.Close() + + // Check for success: For a successful upload, the Docker registry will respond with a 201 Created + if resp.StatusCode != http.StatusOK { + body, _ := ioutil.ReadAll(resp.Body) + return nil, fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body)) + } + + var m *ManifestV2 + if err := json.NewDecoder(resp.Body).Decode(&m); err != nil { + return nil, err + } + + return m, err +} + +func createConfigLayer(layers []string) (*LayerWithBuffer, error) { + // TODO change architecture and OS + config := ConfigV2{ + Architecture: "arm64", + OS: "linux", + RootFS: RootFS{ + Type: "layers", + DiffIDs: layers, + }, + } + + configJSON, err := json.Marshal(config) + if err != nil { + return nil, err + } + + buf := bytes.NewBuffer(configJSON) + digest, size := GetSHA256Digest(buf) + + layer := &LayerWithBuffer{ + Layer: Layer{ + MediaType: "application/vnd.docker.container.image.v1+json", + Digest: digest, + Size: size, + }, + Buffer: buf, + } + return layer, nil +} + +// GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer +func GetSHA256Digest(data *bytes.Buffer) (string, int) { + layerBytes := data.Bytes() + hash := sha256.Sum256(layerBytes) + return "sha256:" + hex.EncodeToString(hash[:]), len(layerBytes) +} + +func startUpload(registryURL string, repositoryName string, username string, password string) (string, error) { + url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", registryURL, repositoryName) + + resp, err := makeRequest("POST", url, nil, nil, username, password) + defer resp.Body.Close() + + if err != nil { + return "", fmt.Errorf("failed to create request: %v", err) + } + + // Check for success + if resp.StatusCode != http.StatusAccepted { + body, _ := ioutil.ReadAll(resp.Body) + return "", fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body)) + } + + // Extract UUID location from header + location := resp.Header.Get("Location") + if location == "" { + return "", fmt.Errorf("Location header is missing in response") + } + + return location, nil +} + +// Function to check if a blob already exists in the Docker registry +func checkBlobExistence(registryURL string, repositoryName string, digest string, username string, password string) (bool, error) { + url := fmt.Sprintf("%s/v2/%s/blobs/%s", registryURL, repositoryName, digest) + + resp, err := makeRequest("HEAD", url, nil, nil, username, password) + defer resp.Body.Close() + + if err != nil { + return false, fmt.Errorf("failed to create request: %v", err) + } + + // Check for success: If the blob exists, the Docker registry will respond with a 200 OK + return resp.StatusCode == http.StatusOK, nil +} + +func uploadBlob(location string, layer *Layer, username string, password string) error { + home, err := os.UserHomeDir() + if err != nil { + return err + } + + // Create URL + url := fmt.Sprintf("%s&digest=%s", location, layer.Digest) + + headers := make(map[string]string) + headers["Content-Length"] = fmt.Sprintf("%d", layer.Size) + headers["Content-Type"] = "application/octet-stream" + + // TODO change from monolithic uploads to chunked uploads + // TODO allow resumability + // TODO allow canceling uploads via DELETE + // TODO allow cross repo blob mount + + filepath := path.Join(home, ".ollama/models/blobs", layer.Digest) + f, err := os.Open(filepath) + if err != nil { + return err + } + + resp, err := makeRequest("PUT", url, headers, f, username, password) + defer resp.Body.Close() + + if err != nil { + return err + } + + // Check for success: For a successful upload, the Docker registry will respond with a 201 Created + if resp.StatusCode != http.StatusCreated { + body, _ := ioutil.ReadAll(resp.Body) + return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body)) + } + + return nil +} + +func downloadBlob(registryURL, repoName, digest, username, password string) error { + home, err := os.UserHomeDir() + if err != nil { + return err + } + + filepath := path.Join(home, ".ollama/models/blobs", digest) + + _, err = os.Stat(filepath) + if !os.IsNotExist(err) { + // we already have the file, so return + log.Printf("already have %s\n", digest) + return nil + } + + url := fmt.Sprintf("%s/v2/%s/blobs/%s", registryURL, repoName, digest) + headers := map[string]string{} + + resp, err := makeRequest("GET", url, headers, nil, username, password) + defer resp.Body.Close() + + // TODO: handle 307 redirects + // TODO: handle range requests to make this resumable + + if resp.StatusCode != http.StatusOK { + body, _ := ioutil.ReadAll(resp.Body) + return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body)) + } + + out, err := os.Create(filepath) + if err != nil { + log.Printf("couldn't create %s", filepath) + return err + } + defer out.Close() + + _, err = io.Copy(out, resp.Body) + if err != nil { + return err + } + + log.Printf("success getting %s\n", digest) + return nil +} + +func makeRequest(method, url string, headers map[string]string, body io.Reader, username, password string) (*http.Response, error) { + req, err := http.NewRequest(method, url, body) + if err != nil { + return nil, err + } + + for k, v := range headers { + req.Header.Set(k, v) + } + + // TODO: better auth + if username != "" && password != "" { + req.SetBasicAuth(username, password) + } + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + + return resp, nil +} diff --git a/server/models.go b/server/models.go index de46e96f9..dee2adde3 100644 --- a/server/models.go +++ b/server/models.go @@ -9,12 +9,17 @@ import ( "os" "path/filepath" "strconv" + + "github.com/jmorganca/ollama/api" ) const directoryURL = "https://ollama.ai/api/models" type Model struct { Name string `json:"name"` + ModelPath string + Prompt string + Options api.Options DisplayName string `json:"display_name"` Parameters string `json:"parameters"` URL string `json:"url"` diff --git a/server/routes.go b/server/routes.go index 36c182e3d..3035c46f4 100644 --- a/server/routes.go +++ b/server/routes.go @@ -3,7 +3,7 @@ package server import ( "embed" "encoding/json" - "errors" + "fmt" "io" "log" "math" @@ -40,6 +40,7 @@ func generate(c *gin.Context) { req := api.GenerateRequest{ Options: api.DefaultOptions(), + Prompt: "", } if err := c.ShouldBindJSON(&req); err != nil { @@ -47,34 +48,28 @@ func generate(c *gin.Context) { return } - if remoteModel, _ := getRemote(req.Model); remoteModel != nil { - req.Model = remoteModel.FullName() - } - if _, err := os.Stat(req.Model); err != nil { - if !errors.Is(err, os.ErrNotExist) { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - req.Model = filepath.Join(cacheDir(), "models", req.Model+".bin") + model, err := GetModel(req.Model) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return } - templateNames := make([]string, 0, len(templates.Templates())) - for _, template := range templates.Templates() { - templateNames = append(templateNames, template.Name()) + templ, err := template.New("").Parse(model.Prompt) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return } - match, _ := matchRankOne(filepath.Base(req.Model), templateNames) - if template := templates.Lookup(match); template != nil { - var sb strings.Builder - if err := template.Execute(&sb, req); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - req.Prompt = sb.String() + var sb strings.Builder + if err = templ.Execute(&sb, req); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return } + req.Prompt = sb.String() - llm, err := llama.New(req.Model, req.Options) + fmt.Printf("prompt = >>>%s<<<\n", req.Prompt) + + llm, err := llama.New(model.ModelPath, req.Options) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -105,40 +100,84 @@ func pull(c *gin.Context) { return } - remote, err := getRemote(req.Model) - if err != nil { - c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) - return + ch := make(chan any) + go func() { + defer close(ch) + fn := func(status, digest string, total, completed int, percent float64) { + ch <- api.PullProgress{ + Status: status, + Digest: digest, + Total: total, + Completed: completed, + Percent: percent, + } + } + if err := PullModel(req.Name, req.Username, req.Password, fn); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } } - // check if completed file exists - fi, err := os.Stat(remote.FullName()) - switch { - case errors.Is(err, os.ErrNotExist): - // noop, file doesn't exist so create it - case err != nil: - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - default: - c.JSON(http.StatusOK, api.PullProgress{ - Total: fi.Size(), - Completed: fi.Size(), - Percent: 100, - }) + streamResponse(c, ch) +} +func push(c *gin.Context) { + var req api.PushRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } ch := make(chan any) go func() { defer close(ch) - saveModel(remote, func(total, completed int64) { - ch <- api.PullProgress{ + fn := func(status, digest string, total, completed int, percent float64) { + ch <- api.PushProgress{ + Status: status, + Digest: digest, Total: total, Completed: completed, - Percent: float64(completed) / float64(total) * 100, + Percent: percent, } - }) + } + if err := PushModel(req.Name, req.Username, req.Password, fn); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + }() + + streamResponse(c, ch) +} + +func create(c *gin.Context) { + var req api.CreateRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + return + } + + // NOTE consider passing the entire Modelfile in the json instead of the path to it + + file, err := os.Open(req.Path) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + return + } + defer file.Close() + + ch := make(chan any) + go func() { + defer close(ch) + fn := func(status string) { + ch <- api.CreateProgress{ + Status: status, + } + } + + if err := CreateModel(req.Name, file, fn); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + return + } }() streamResponse(c, ch) @@ -153,6 +192,8 @@ func Serve(ln net.Listener) error { r.POST("/api/pull", pull) r.POST("/api/generate", generate) + r.POST("/api/create", create) + r.POST("/api/push", push) log.Printf("Listening on %s", ln.Addr()) s := &http.Server{