engine: add remote proxy (#12307)

This commit is contained in:
Patrick Devine
2025-09-17 14:40:53 -07:00
committed by GitHub
parent 9c5bf342bc
commit 8b894933a7
12 changed files with 948 additions and 100 deletions

View File

@@ -10,8 +10,11 @@ import (
"io"
"io/fs"
"log/slog"
"net"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"slices"
"strings"
@@ -39,6 +42,14 @@ var (
)
func (s *Server) CreateHandler(c *gin.Context) {
config := &ConfigV2{
OS: "linux",
Architecture: "amd64",
RootFS: RootFS{
Type: "layers",
},
}
var r api.CreateRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@@ -48,6 +59,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
return
}
config.Renderer = r.Renderer
config.Parser = r.Parser
for v := range r.Files {
if !fs.ValidPath(v) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
@@ -77,20 +91,34 @@ func (s *Server) CreateHandler(c *gin.Context) {
oldManifest, _ := ParseNamedManifest(name)
var baseLayers []*layerGGML
var err error
var remote bool
if r.From != "" {
slog.Debug("create model from model name")
slog.Debug("create model from model name", "from", r.From)
fromName := model.ParseName(r.From)
if !fromName.IsValid() {
ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest}
return
}
if r.RemoteHost != "" {
ru, err := remoteURL(r.RemoteHost)
if err != nil {
ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest}
return
}
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
config.RemoteModel = r.From
config.RemoteHost = ru
remote = true
} else {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
baseLayers, err = parseFromModel(ctx, fromName, fn)
if err != nil {
ch <- gin.H{"error": err.Error()}
baseLayers, err = parseFromModel(ctx, fromName, fn)
if err != nil {
ch <- gin.H{"error": err.Error()}
}
}
} else if r.Files != nil {
baseLayers, err = convertModelFromFiles(r.Files, baseLayers, false, fn)
@@ -110,7 +138,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
}
var adapterLayers []*layerGGML
if r.Adapters != nil {
if !remote && r.Adapters != nil {
adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn)
if err != nil {
for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} {
@@ -128,7 +156,56 @@ func (s *Server) CreateHandler(c *gin.Context) {
baseLayers = append(baseLayers, adapterLayers...)
}
if err := createModel(r, name, baseLayers, fn); err != nil {
// Info is not currently exposed by Modelfiles, but allows overriding various
// config values
if r.Info != nil {
caps, ok := r.Info["capabilities"]
if ok {
switch tcaps := caps.(type) {
case []any:
caps := make([]string, len(tcaps))
for i, c := range tcaps {
str, ok := c.(string)
if !ok {
continue
}
caps[i] = str
}
config.Capabilities = append(config.Capabilities, caps...)
}
}
strFromInfo := func(k string) string {
v, ok := r.Info[k]
if ok {
val := v.(string)
return val
}
return ""
}
vFromInfo := func(k string) float64 {
v, ok := r.Info[k]
if ok {
val := v.(float64)
return val
}
return 0
}
config.ModelFamily = strFromInfo("model_family")
if config.ModelFamily != "" {
config.ModelFamilies = []string{config.ModelFamily}
}
config.BaseName = strFromInfo("base_name")
config.FileType = strFromInfo("quantization_level")
config.ModelType = strFromInfo("parameter_size")
config.ContextLen = int(vFromInfo("context_length"))
config.EmbedLen = int(vFromInfo("embedding_length"))
}
if err := createModel(r, name, baseLayers, config, fn); err != nil {
if errors.Is(err, errBadTemplate) {
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
return
@@ -154,6 +231,51 @@ func (s *Server) CreateHandler(c *gin.Context) {
streamResponse(c, ch)
}
func remoteURL(raw string) (string, error) {
// Specialcase: user supplied only a path ("/foo/bar").
if strings.HasPrefix(raw, "/") {
return (&url.URL{
Scheme: "http",
Host: net.JoinHostPort("localhost", "11434"),
Path: path.Clean(raw),
}).String(), nil
}
if !strings.Contains(raw, "://") {
raw = "http://" + raw
}
if raw == "ollama.com" || raw == "http://ollama.com" {
raw = "https://ollama.com:443"
}
u, err := url.Parse(raw)
if err != nil {
return "", fmt.Errorf("parse error: %w", err)
}
if u.Host == "" {
u.Host = "localhost"
}
hostPart, portPart, err := net.SplitHostPort(u.Host)
if err == nil {
u.Host = net.JoinHostPort(hostPart, portPart)
} else {
u.Host = net.JoinHostPort(u.Host, "11434")
}
if u.Path != "" {
u.Path = path.Clean(u.Path)
}
if u.Path == "/" {
u.Path = ""
}
return u.String(), nil
}
func convertModelFromFiles(files map[string]string, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) {
switch detectModelTypeFromFiles(files) {
case "safetensors":
@@ -316,17 +438,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
return ggml.KV{}, fmt.Errorf("no base model was found")
}
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, fn func(resp api.ProgressResponse)) (err error) {
config := ConfigV2{
OS: "linux",
Architecture: "amd64",
RootFS: RootFS{
Type: "layers",
},
Renderer: r.Renderer,
Parser: r.Parser,
}
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
var layers []Layer
for _, layer := range baseLayers {
if layer.GGML != nil {
@@ -406,7 +518,7 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
return err
}
configLayer, err := createConfigLayer(layers, config)
configLayer, err := createConfigLayer(layers, *config)
if err != nil {
return err
}