client create modelfile

This commit is contained in:
Michael Yang
2023-11-14 14:07:40 -08:00
parent 3ca56b5ada
commit 1552cee59f
4 changed files with 147 additions and 14 deletions

View File

@@ -2,6 +2,7 @@ package server
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
@@ -649,6 +650,60 @@ func CopyModelHandler(c *gin.Context) {
}
}
func GetBlobHandler(c *gin.Context) {
path, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if _, err := os.Stat(path); err != nil {
c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))})
return
}
c.JSON(http.StatusOK, api.CreateBlobResponse{Path: path})
}
func CreateBlobHandler(c *gin.Context) {
hash := sha256.New()
temp, err := os.CreateTemp("", c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
defer temp.Close()
defer os.Remove(temp.Name())
if _, err := io.Copy(temp, io.TeeReader(c.Request.Body, hash)); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if fmt.Sprintf("sha256:%x", hash.Sum(nil)) != c.Param("digest") {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "digest does not match body"})
return
}
if err := temp.Close(); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
targetPath, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := os.Rename(temp.Name(), targetPath); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, api.CreateBlobResponse{Path: targetPath})
}
var defaultAllowOrigins = []string{
"localhost",
"127.0.0.1",
@@ -708,6 +763,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
r.POST("/api/copy", CopyModelHandler)
r.DELETE("/api/delete", DeleteModelHandler)
r.POST("/api/show", ShowModelHandler)
r.POST("/api/blobs/:digest", CreateBlobHandler)
for _, method := range []string{http.MethodGet, http.MethodHead} {
r.Handle(method, "/", func(c *gin.Context) {
@@ -715,6 +771,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
})
r.Handle(method, "/api/tags", ListModelsHandler)
r.Handle(method, "/api/blobs/:digest/path", GetBlobHandler)
}
log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)