// ollama/server/routes.go — HTTP route handlers for the Ollama API server.
package server
import (
2024-06-17 10:38:55 -07:00
"bytes"
2024-04-30 10:55:19 -07:00
"cmp"
"context"
"encoding/binary"
2023-07-06 10:40:11 -07:00
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"math"
"net"
"net/http"
"net/netip"
"os"
"os/signal"
2023-07-14 17:27:14 -07:00
"path/filepath"
2024-05-21 21:30:52 -07:00
"slices"
2023-07-06 10:40:11 -07:00
"strings"
"syscall"
2023-07-12 18:18:06 -07:00
"time"
2023-07-21 18:01:24 -07:00
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/model/mllama"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/runners"
2024-06-10 14:54:42 -07:00
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
2024-04-16 16:22:38 -07:00
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
// mode selects the gin engine mode; init validates it and falls back to
// gin.DebugMode for unrecognized values before applying it.
var mode string = gin.DebugMode
// Server wires the HTTP API handlers to the model scheduler.
type Server struct {
	addr  net.Addr   // bound listen address; consulted by allowedHostsMiddleware
	sched *Scheduler // schedules model runners for inference requests
}
2023-08-22 09:48:35 -07:00
func init() {
switch mode {
case gin.DebugMode:
case gin.ReleaseMode:
case gin.TestMode:
default:
mode = gin.DebugMode
}
gin.SetMode(mode)
}
// Sentinel errors wrapped into handler error messages.
var (
	errRequired    = errors.New("is required")
	errBadTemplate = errors.New("template error")
)
2024-06-20 11:00:08 -07:00
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
return api.Options{}, err
}
if err := opts.FromMap(requestOpts); err != nil {
return api.Options{}, err
}
return opts, nil
}
// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
	if name == "" {
		return nil, nil, nil, fmt.Errorf("model %w", errRequired)
	}

	model, err := GetModel(name)
	if err != nil {
		return nil, nil, nil, err
	}

	// Reject early if the model cannot serve the requested capabilities
	// (e.g. completion, insert).
	if err := model.CheckCapabilities(caps...); err != nil {
		return nil, nil, nil, fmt.Errorf("%s %w", name, err)
	}

	// Merge defaults, model options, and request overrides.
	opts, err := modelOptions(model, requestOpts)
	if err != nil {
		return nil, nil, nil, err
	}

	// Block until the scheduler either allocates a runner or reports failure.
	runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
	var runner *runnerRef
	select {
	case runner = <-runnerCh:
	case err = <-errCh:
		return nil, nil, nil, err
	}

	return runner.llama, model, &opts, nil
}
// GenerateHandler answers POST /api/generate: it renders the prompt through
// the model template (unless raw mode), runs completion on a scheduled
// runner, and either streams responses or aggregates them into one reply.
func (s *Server) GenerateHandler(c *gin.Context) {
	checkpointStart := time.Now()

	var req api.GenerateRequest
	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	} else if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	name := model.ParseName(req.Model)
	if !name.IsValid() {
		// Ideally this is "invalid model name" but we're keeping with
		// what the API currently returns until we can change it.
		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
		return
	}

	// We cannot currently consolidate this into GetModel because we'll
	// induce infinite recursion given the current code structure.
	name, err := getExistingName(name)
	if err != nil {
		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
		return
	}

	// NOTE: this variable shadows the imported "model" package for the rest
	// of the handler.
	model, err := GetModel(name.String())
	if err != nil {
		switch {
		case errors.Is(err, fs.ErrNotExist):
			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
		case err.Error() == errtypes.InvalidModelNameErrMsg:
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		default:
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		}
		return
	}

	// Empty prompt plus keep_alive == 0 is the "unload now" request:
	// expire the runner and reply immediately.
	if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
		s.sched.expireRunner(model)

		c.JSON(http.StatusOK, api.GenerateResponse{
			Model:      req.Model,
			CreatedAt:  time.Now().UTC(),
			Response:   "",
			Done:       true,
			DoneReason: "unload",
		})
		return
	}

	// Raw mode bypasses templating, so template/system/context are invalid.
	if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
		return
	}

	caps := []Capability{CapabilityCompletion}
	if req.Suffix != "" {
		// A suffix means fill-in-the-middle, which needs the insert capability.
		caps = append(caps, CapabilityInsert)
	}

	r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
	if errors.Is(err, errCapabilityCompletion) {
		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
		return
	} else if err != nil {
		handleScheduleError(c, req.Model, err)
		return
	}

	checkpointLoaded := time.Now()

	// An empty prompt (without keep_alive == 0 handled above) just loads
	// the model.
	if req.Prompt == "" {
		c.JSON(http.StatusOK, api.GenerateResponse{
			Model:      req.Model,
			CreatedAt:  time.Now().UTC(),
			Done:       true,
			DoneReason: "load",
		})
		return
	}

	isMllama := checkMllamaModelFamily(model)
	if isMllama && len(req.Images) > 1 {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image: more than one image sent"})
		return
	}

	images := make([]llm.ImageData, len(req.Images))
	for i := range req.Images {
		if isMllama {
			// mllama images are preprocessed into tensor bytes plus an
			// aspect-ratio index rather than being passed through verbatim.
			data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i]))
			if err != nil {
				c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
				return
			}

			ar, ok := opts["aspectRatioIndex"].(int)
			if !ok {
				c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
				return
			}

			buf := new(bytes.Buffer)
			err = binary.Write(buf, binary.LittleEndian, data)
			if err != nil {
				c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
				return
			}

			images[i] = llm.ImageData{ID: i, Data: buf.Bytes(), AspectRatioID: ar}
		} else {
			images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
		}
	}

	prompt := req.Prompt
	if !req.Raw {
		tmpl := m.Template
		if req.Template != "" {
			tmpl, err = template.Parse(req.Template)
			if err != nil {
				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
				return
			}
		}

		var values template.Values
		if req.Suffix != "" {
			// Fill-in-the-middle: the template receives prompt and suffix.
			values.Prompt = prompt
			values.Suffix = req.Suffix
		} else {
			var msgs []api.Message
			if req.System != "" {
				msgs = append(msgs, api.Message{Role: "system", Content: req.System})
			} else if m.System != "" {
				msgs = append(msgs, api.Message{Role: "system", Content: m.System})
			}

			// Stored model messages are only injected when the caller is not
			// supplying prior context tokens.
			if req.Context == nil {
				msgs = append(msgs, m.Messages...)
			}

			for _, i := range images {
				imgPrompt := ""
				if isMllama {
					imgPrompt = "<|image|>"
				}
				msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]"+imgPrompt, i.ID)})
			}

			values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
		}

		var b bytes.Buffer
		if req.Context != nil {
			// Prepend the detokenized prior context ahead of the rendered
			// template output.
			slog.Warn("the context field is deprecated and will be removed in a future version of Ollama")
			s, err := r.Detokenize(c.Request.Context(), req.Context)
			if err != nil {
				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
				return
			}
			b.WriteString(s)
		}

		if err := tmpl.Execute(&b, values); err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}

		prompt = b.String()
	}

	slog.Debug("generate request", "images", len(images), "prompt", prompt)

	// The completion runs in a goroutine feeding ch; either path below
	// (streaming or aggregated) drains it until close.
	ch := make(chan any)
	go func() {
		// TODO (jmorganca): avoid building the response twice both here and below
		var sb strings.Builder
		defer close(ch)
		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
			Prompt:  prompt,
			Images:  images,
			Format:  req.Format,
			Options: opts,
		}, func(cr llm.CompletionResponse) {
			res := api.GenerateResponse{
				Model:      req.Model,
				CreatedAt:  time.Now().UTC(),
				Response:   cr.Content,
				Done:       cr.Done,
				DoneReason: cr.DoneReason,
				Metrics: api.Metrics{
					PromptEvalCount:    cr.PromptEvalCount,
					PromptEvalDuration: cr.PromptEvalDuration,
					EvalCount:          cr.EvalCount,
					EvalDuration:       cr.EvalDuration,
				},
			}

			// Accumulate the full response so the deprecated context field
			// can be computed when the run completes.
			if _, err := sb.WriteString(cr.Content); err != nil {
				ch <- gin.H{"error": err.Error()}
			}

			if cr.Done {
				res.TotalDuration = time.Since(checkpointStart)
				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)

				if !req.Raw {
					tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
					if err != nil {
						ch <- gin.H{"error": err.Error()}
						return
					}
					res.Context = tokens
				}
			}

			ch <- res
		}); err != nil {
			ch <- gin.H{"error": err.Error()}
		}
	}()

	if req.Stream != nil && !*req.Stream {
		// Non-streaming: drain the channel and return one aggregated response
		// carrying the final metrics.
		var r api.GenerateResponse
		var sb strings.Builder
		for rr := range ch {
			switch t := rr.(type) {
			case api.GenerateResponse:
				sb.WriteString(t.Response)
				r = t
			case gin.H:
				msg, ok := t["error"].(string)
				if !ok {
					msg = "unexpected error format in response"
				}

				c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
				return
			default:
				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
				return
			}
		}

		r.Response = sb.String()
		c.JSON(http.StatusOK, r)
		return
	}

	streamResponse(c, ch)
}
func (s *Server) EmbedHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.EmbedRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
truncate := true
if req.Truncate != nil && !*req.Truncate {
truncate = false
}
var input []string
switch i := req.Input.(type) {
case string:
if len(i) > 0 {
input = append(input, i)
}
case []any:
for _, v := range i {
if _, ok := v.(string); !ok {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
input = append(input, v.(string))
}
default:
if req.Input != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
}
name, err := getExistingName(model.ParseName(req.Model))
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
}
checkpointLoaded := time.Now()
if len(input) == 0 {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
}
kvData, err := getKVData(m.ModelPath, false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var count int
for i, s := range input {
tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen {
if !truncate {
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
return
}
tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
count += len(tokens)
input[i] = s
}
var g errgroup.Group
embeddings := make([][]float32, len(input))
for i, text := range input {
g.Go(func() error {
embedding, err := r.Embedding(c.Request.Context(), text)
if err != nil {
return err
}
embeddings[i] = normalize(embedding)
return nil
})
}
if err := g.Wait(); err != nil {
slog.Error("embedding generation failed", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embeddings: %v", err)})
return
}
resp := api.EmbedResponse{
Model: req.Model,
Embeddings: embeddings,
TotalDuration: time.Since(checkpointStart),
LoadDuration: checkpointLoaded.Sub(checkpointStart),
PromptEvalCount: count,
}
c.JSON(http.StatusOK, resp)
}
// normalize scales vec to unit length (L2 norm) in place and returns it.
// A zero vector is left as all zeros rather than dividing by zero.
func normalize(vec []float32) []float32 {
	var sumSquares float32
	for _, component := range vec {
		sumSquares += component * component
	}

	var scale float32
	if sumSquares > 0 {
		scale = float32(1.0 / math.Sqrt(float64(sumSquares)))
	}

	for i := range vec {
		vec[i] *= scale
	}
	return vec
}
func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest
2024-06-17 10:38:55 -07:00
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
2024-06-17 10:38:55 -07:00
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
name := model.ParseName(req.Model)
if !name.IsValid() {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
if err != nil {
2024-06-20 11:00:08 -07:00
handleScheduleError(c, req.Model, err)
return
}
// an empty request loads the model
if req.Prompt == "" {
c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
return
}
embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embedding: %v", err)})
return
}
var e []float64
for _, v := range embedding {
e = append(e, float64(v))
}
resp := api.EmbeddingResponse{
Embedding: e,
}
c.JSON(http.StatusOK, resp)
}
func (s *Server) PullHandler(c *gin.Context) {
2023-07-11 11:54:22 -07:00
var req api.PullRequest
2023-10-18 16:08:42 -07:00
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
2023-07-11 11:54:22 -07:00
return
}
name := model.ParseName(cmp.Or(req.Model, req.Name))
if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
return
}
name, err = getExistingName(name)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ch := make(chan any)
go func() {
defer close(ch)
2023-07-18 18:51:30 -07:00
fn := func(r api.ProgressResponse) {
ch <- r
}
2023-07-18 18:51:30 -07:00
2024-02-14 11:29:49 -08:00
regOpts := &registryOptions{
Insecure: req.Insecure,
}
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
2023-07-20 12:12:08 -07:00
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
waitForStream(c, ch)
return
}
streamResponse(c, ch)
}
func (s *Server) PushHandler(c *gin.Context) {
var req api.PushRequest
2023-10-18 16:08:42 -07:00
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
2023-07-11 11:54:22 -07:00
return
}
2023-07-06 10:40:11 -07:00
var mname string
if req.Model != "" {
mname = req.Model
} else if req.Name != "" {
mname = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
ch := make(chan any)
go func() {
defer close(ch)
2023-07-18 18:51:30 -07:00
fn := func(r api.ProgressResponse) {
ch <- r
}
2023-07-18 18:51:30 -07:00
2024-02-14 11:29:49 -08:00
regOpts := &registryOptions{
Insecure: req.Insecure,
}
2023-10-09 10:24:27 -07:00
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
name, err := getExistingName(model.ParseName(mname))
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
if err := PushModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
2023-07-20 12:12:08 -07:00
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
waitForStream(c, ch)
return
}
streamResponse(c, ch)
}
// getExistingName searches the models directory for the longest prefix match of
// the input name and returns the input name with all existing parts replaced
// with each part found. If no parts are found, the input name is returned as
// is.
func getExistingName(n model.Name) (model.Name, error) {
	var zero model.Name
	existing, err := Manifests(true)
	if err != nil {
		return zero, err
	}
	var set model.Name // tracks parts already canonicalized
	for e := range existing {
		// FIX: set was declared to track canonicalized parts but was never
		// written, so a later manifest could overwrite the casing chosen from
		// an earlier one. Record each part the first time it is matched.
		if set.Host == "" && strings.EqualFold(e.Host, n.Host) {
			n.Host = e.Host
			set.Host = e.Host
		}
		if set.Namespace == "" && strings.EqualFold(e.Namespace, n.Namespace) {
			n.Namespace = e.Namespace
			set.Namespace = e.Namespace
		}
		if set.Model == "" && strings.EqualFold(e.Model, n.Model) {
			n.Model = e.Model
			set.Model = e.Model
		}
		if set.Tag == "" && strings.EqualFold(e.Tag, n.Tag) {
			n.Tag = e.Tag
			set.Tag = e.Tag
		}
	}
	return n, nil
}
func (s *Server) DeleteHandler(c *gin.Context) {
var r api.DeleteRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
2023-10-18 16:08:42 -07:00
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
2023-10-18 16:08:42 -07:00
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
2023-07-20 16:09:23 -07:00
return
}
n := model.ParseName(cmp.Or(r.Model, r.Name))
if !n.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
return
}
2023-09-26 17:28:14 -07:00
n, err := getExistingName(n)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
return
}
m, err := ParseNamedManifest(n)
2023-09-26 17:28:14 -07:00
if err != nil {
switch {
case os.IsNotExist(err):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
2023-09-26 17:28:14 -07:00
return
}
if err := m.Remove(); err != nil {
2023-09-26 17:28:14 -07:00
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := m.RemoveLayers(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
2023-07-20 16:09:23 -07:00
}
func (s *Server) ShowHandler(c *gin.Context) {
2023-09-06 11:04:17 -07:00
var req api.ShowRequest
2023-10-18 16:08:42 -07:00
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
2023-09-06 11:04:17 -07:00
return
}
if req.Model != "" {
2024-01-18 15:36:50 -08:00
// noop
} else if req.Name != "" {
2024-01-18 15:36:50 -08:00
req.Model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
resp, err := GetModelInfo(req)
2023-09-06 11:04:17 -07:00
if err != nil {
switch {
case os.IsNotExist(err):
2024-01-18 15:36:50 -08:00
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
case err.Error() == errtypes.InvalidModelNameErrMsg:
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
2023-09-06 11:04:17 -07:00
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, resp)
}
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
name := model.ParseName(req.Model)
if !name.IsValid() {
return nil, errModelPathInvalid
}
name, err := getExistingName(name)
if err != nil {
return nil, err
}
m, err := GetModel(name.String())
2023-09-06 11:04:17 -07:00
if err != nil {
return nil, err
}
modelDetails := api.ModelDetails{
ParentModel: m.ParentModel,
Format: m.Config.ModelFormat,
Family: m.Config.ModelFamily,
Families: m.Config.ModelFamilies,
ParameterSize: m.Config.ModelType,
QuantizationLevel: m.Config.FileType,
}
if req.System != "" {
m.System = req.System
}
2024-06-17 10:38:55 -07:00
msgs := make([]api.Message, len(m.Messages))
for i, msg := range m.Messages {
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
2024-01-25 12:12:36 -08:00
}
manifest, err := ParseNamedManifest(name)
if err != nil {
return nil, err
}
2023-09-06 11:04:17 -07:00
resp := &api.ShowResponse{
License: strings.Join(m.License, "\n"),
System: m.System,
2024-06-10 14:54:42 -07:00
Template: m.Template.String(),
Details: modelDetails,
Messages: msgs,
ModifiedAt: manifest.fi.ModTime(),
2023-09-06 11:04:17 -07:00
}
var params []string
cs := 30
for k, v := range m.Options {
2023-09-06 11:04:17 -07:00
switch val := v.(type) {
case []interface{}:
for _, nv := range val {
2024-01-16 10:34:44 -08:00
params = append(params, fmt.Sprintf("%-*s %#v", cs, k, nv))
2023-09-06 11:04:17 -07:00
}
2024-01-16 10:34:44 -08:00
default:
params = append(params, fmt.Sprintf("%-*s %#v", cs, k, v))
2023-09-06 11:04:17 -07:00
}
}
resp.Parameters = strings.Join(params, "\n")
for k, v := range req.Options {
if _, ok := req.Options[k]; ok {
m.Options[k] = v
}
}
var sb strings.Builder
fmt.Fprintln(&sb, "# Modelfile generated by \"ollama show\"")
fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:")
fmt.Fprintf(&sb, "# FROM %s\n\n", m.ShortName)
fmt.Fprint(&sb, m.String())
resp.Modelfile = sb.String()
kvData, err := getKVData(m.ModelPath, req.Verbose)
if err != nil {
return nil, err
}
delete(kvData, "general.name")
delete(kvData, "tokenizer.chat_template")
resp.ModelInfo = kvData
if len(m.ProjectorPaths) > 0 {
projectorData, err := getKVData(m.ProjectorPaths[0], req.Verbose)
if err != nil {
return nil, err
}
resp.ProjectorInfo = projectorData
}
2023-09-06 11:04:17 -07:00
return resp, nil
}
// getKVData loads the GGUF key/value metadata for the model file at digest.
// When verbose is false, arrays are capped at load time (maxArraySize 0) and
// any remaining arrays longer than five entries are blanked to keep the
// response small.
func getKVData(digest string, verbose bool) (llm.KV, error) {
	arrayLimit := 0
	if verbose {
		arrayLimit = -1
	}

	loaded, err := llm.LoadModel(digest, arrayLimit)
	if err != nil {
		return nil, err
	}

	kv := loaded.KV()
	if verbose {
		return kv, nil
	}

	for k, v := range kv {
		if arr, ok := v.([]any); ok && len(arr) > 5 {
			kv[k] = []any{}
		}
	}
	return kv, nil
}
func (s *Server) ListHandler(c *gin.Context) {
ms, err := Manifests(true)
2023-07-18 09:09:45 -07:00
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
2023-08-30 14:14:12 -04:00
models := []api.ListModelResponse{}
2024-05-06 16:34:13 -07:00
for n, m := range ms {
var cf ConfigV2
if m.Config.Digest != "" {
f, err := m.Config.Open()
if err != nil {
slog.Warn("bad manifest filepath", "name", n, "error", err)
continue
}
defer f.Close()
if err := json.NewDecoder(f).Decode(&cf); err != nil {
slog.Warn("bad manifest config", "name", n, "error", err)
continue
}
2023-07-18 09:09:45 -07:00
}
2023-08-30 14:14:12 -04:00
2024-05-06 16:34:13 -07:00
// tag should never be masked
models = append(models, api.ListModelResponse{
2024-05-06 16:34:13 -07:00
Model: n.DisplayShortest(),
Name: n.DisplayShortest(),
Size: m.Size(),
Digest: m.digest,
ModifiedAt: m.fi.ModTime(),
Details: api.ModelDetails{
Format: cf.ModelFormat,
Family: cf.ModelFamily,
Families: cf.ModelFamilies,
ParameterSize: cf.ModelType,
QuantizationLevel: cf.FileType,
},
})
2023-07-18 09:09:45 -07:00
}
slices.SortStableFunc(models, func(i, j api.ListModelResponse) int {
2024-04-17 14:54:14 -07:00
// most recently modified first
return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
})
2023-07-19 15:00:28 -07:00
c.JSON(http.StatusOK, api.ListResponse{Models: models})
2023-07-18 09:09:45 -07:00
}
func (s *Server) CopyHandler(c *gin.Context) {
2024-04-16 16:22:38 -07:00
var r api.CopyRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
2023-10-18 16:08:42 -07:00
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
2024-04-16 16:22:38 -07:00
} else if err != nil {
2023-10-18 16:08:42 -07:00
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
2023-07-24 11:27:28 -04:00
return
}
2024-04-16 16:22:38 -07:00
src := model.ParseName(r.Source)
if !src.IsValid() {
2024-05-01 12:39:05 -07:00
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("source %q is invalid", r.Source)})
return
}
src, err := getExistingName(src)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
2024-04-16 16:22:38 -07:00
dst := model.ParseName(r.Destination)
if !dst.IsValid() {
2024-05-07 17:35:52 -07:00
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Destination)})
2023-07-24 11:27:28 -04:00
return
}
dst, err = getExistingName(dst)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
2024-04-16 16:22:38 -07:00
if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
} else if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
2023-07-24 11:27:28 -04:00
}
func (s *Server) HeadBlobHandler(c *gin.Context) {
2023-11-14 14:07:40 -08:00
path, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if _, err := os.Stat(path); err != nil {
c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))})
return
}
2023-11-15 13:55:37 -08:00
c.Status(http.StatusOK)
2023-11-14 14:07:40 -08:00
}
// CreateBlobHandler answers POST /api/blobs/:digest, storing the request body
// as a blob layer after verifying its digest. Responds 200 when the blob (or
// a still-valid intermediate for it) already exists, 201 after writing a new
// layer.
func (s *Server) CreateBlobHandler(c *gin.Context) {
	// A previous create may have recorded an intermediate blob for this
	// digest; if that file is still on disk, the upload can be skipped.
	if ib, ok := intermediateBlobs[c.Param("digest")]; ok {
		p, err := GetBlobsPath(ib)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}

		if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) {
			// Stale cache entry: drop it and fall through to a normal create.
			slog.Info("evicting intermediate blob which no longer exists", "digest", ib)
			delete(intermediateBlobs, c.Param("digest"))
		} else if err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		} else {
			c.Status(http.StatusOK)
			return
		}
	}

	path, err := GetBlobsPath(c.Param("digest"))
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	_, err = os.Stat(path)
	switch {
	case errors.Is(err, os.ErrNotExist):
		// noop: blob not present yet; fall through and write it
	case err != nil:
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	default:
		// Blob already exists on disk.
		c.Status(http.StatusOK)
		return
	}

	// NewLayer consumes the request body; its computed digest is compared
	// against the URL parameter below.
	layer, err := NewLayer(c.Request.Body, "")
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	if layer.Digest != c.Param("digest") {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("digest mismatch, expected %q, got %q", c.Param("digest"), layer.Digest)})
		return
	}

	c.Status(http.StatusCreated)
}
2024-03-09 00:22:08 -08:00
func isLocalIP(ip netip.Addr) bool {
if interfaces, err := net.Interfaces(); err == nil {
for _, iface := range interfaces {
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, a := range addrs {
if parsed, _, err := net.ParseCIDR(a.String()); err == nil {
if parsed.String() == ip.String() {
return true
}
}
}
}
}
return false
}
// allowedHost reports whether host names this machine: empty, "localhost",
// the machine's hostname (case-insensitive), or any name under a local-only
// TLD (.localhost, .local, .internal).
func allowedHost(host string) bool {
	host = strings.ToLower(host)

	switch host {
	case "", "localhost":
		return true
	}

	if hostname, err := os.Hostname(); err == nil && host == strings.ToLower(hostname) {
		return true
	}

	// check if the host is a local TLD
	for _, tld := range []string{"localhost", "local", "internal"} {
		if strings.HasSuffix(host, "."+tld) {
			return true
		}
	}

	return false
}
// allowedHostsMiddleware restricts which Host headers are accepted when the
// server is bound to a loopback address; non-local hosts get 403.
// NOTE(review): this looks like DNS-rebinding protection for local servers —
// confirm against upstream history.
func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
	return func(c *gin.Context) {
		if addr == nil {
			c.Next()
			return
		}

		// Only filter when the listener is on a loopback address.
		// (The parsed addr deliberately shadows the net.Addr parameter.)
		if addr, err := netip.ParseAddrPort(addr.String()); err == nil && !addr.Addr().IsLoopback() {
			c.Next()
			return
		}

		// The Host header may or may not include a port.
		host, _, err := net.SplitHostPort(c.Request.Host)
		if err != nil {
			host = c.Request.Host
		}

		// Literal IPs: allow loopback, private, unspecified, or any address
		// assigned to a local interface.
		if addr, err := netip.ParseAddr(host); err == nil {
			if addr.IsLoopback() || addr.IsPrivate() || addr.IsUnspecified() || isLocalIP(addr) {
				c.Next()
				return
			}
		}

		// Hostnames: allow localhost, the machine hostname, and local TLDs.
		if allowedHost(host) {
			// CORS preflight requests are answered immediately with 204.
			if c.Request.Method == http.MethodOptions {
				c.AbortWithStatus(http.StatusNoContent)
				return
			}

			c.Next()
			return
		}

		c.AbortWithStatus(http.StatusForbidden)
	}
}
// GenerateRoutes builds the gin handler tree: CORS and host-filtering
// middleware, the native /api endpoints, and the OpenAI-compatible /v1
// endpoints.
func (s *Server) GenerateRoutes() http.Handler {
	config := cors.DefaultConfig()
	config.AllowWildcard = true
	config.AllowBrowserExtensions = true
	config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}

	// Allow the x-stainless-* headers sent by Stainless-generated
	// (OpenAI-style) client SDKs.
	openAIProperties := []string{"lang", "package-version", "os", "arch", "retry-count", "runtime", "runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval"}
	for _, prop := range openAIProperties {
		config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
	}
	config.AllowOrigins = envconfig.Origins()

	r := gin.Default()
	r.Use(
		cors.New(config),
		allowedHostsMiddleware(s.addr),
	)

	r.POST("/api/pull", s.PullHandler)
	r.POST("/api/generate", s.GenerateHandler)
	r.POST("/api/chat", s.ChatHandler)
	r.POST("/api/embed", s.EmbedHandler)
	r.POST("/api/embeddings", s.EmbeddingsHandler)
	r.POST("/api/create", s.CreateHandler)
	r.POST("/api/push", s.PushHandler)
	r.POST("/api/copy", s.CopyHandler)
	r.DELETE("/api/delete", s.DeleteHandler)
	r.POST("/api/show", s.ShowHandler)
	r.POST("/api/blobs/:digest", s.CreateBlobHandler)
	r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
	r.GET("/api/ps", s.PsHandler)

	// Compatibility endpoints
	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
	r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
	r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
	r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
	r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)

	// Health, tags, and version endpoints answer both GET and HEAD.
	for _, method := range []string{http.MethodGet, http.MethodHead} {
		r.Handle(method, "/", func(c *gin.Context) {
			c.String(http.StatusOK, "Ollama is running")
		})
		r.Handle(method, "/api/tags", s.ListHandler)
		r.Handle(method, "/api/version", func(c *gin.Context) {
			c.JSON(http.StatusOK, gin.H{"version": version.Version})
		})
	}

	return r
}
func Serve(ln net.Listener) error {
level := slog.LevelInfo
2024-07-03 16:00:54 -07:00
if envconfig.Debug() {
level = slog.LevelDebug
}
slog.Info("server config", "env", envconfig.Values())
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
AddSource: true,
ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
if attr.Key == slog.SourceKey {
source := attr.Value.Any().(*slog.Source)
source.File = filepath.Base(source.File)
}
return attr
},
})
slog.SetDefault(slog.New(handler))
blobsDir, err := GetBlobsPath("")
if err != nil {
return err
}
if err := fixBlobs(blobsDir); err != nil {
return err
}
2024-07-03 17:22:13 -07:00
if !envconfig.NoPrune() {
if _, err := Manifests(false); err != nil {
slog.Warn("corrupt manifests detected, skipping prune operation. Re-pull or delete to clear", "error", err)
} else {
// clean up unused layers and manifests
if err := PruneLayers(); err != nil {
return err
}
2023-12-14 16:47:40 -08:00
manifestsPath, err := GetManifestPath()
if err != nil {
return err
}
2023-12-14 16:47:40 -08:00
if err := PruneDirectory(manifestsPath); err != nil {
return err
}
2023-12-14 16:47:40 -08:00
}
}
ctx, done := context.WithCancel(context.Background())
schedCtx, schedDone := context.WithCancel(ctx)
sched := InitScheduler(schedCtx)
s := &Server{addr: ln.Addr(), sched: sched}
http.Handle("/", s.GenerateRoutes())
2023-12-14 16:47:40 -08:00
slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
2023-12-14 16:47:40 -08:00
srvr := &http.Server{
// Use http.DefaultServeMux so we get net/http/pprof for
// free.
//
// TODO(bmizerany): Decide if we want to make this
// configurable so it is not exposed by default, or allow
// users to bind it to a different port. This was a quick
// and easy way to get pprof, but it may not be the best
// way.
Handler: nil,
}
// listen for a ctrl+c and stop any loaded llm
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-signals
srvr.Close()
schedDone()
sched.unloadAllRunners()
done()
}()
build: Make target improvements (#7499) * llama: wire up builtin runner This adds a new entrypoint into the ollama CLI to run the cgo built runner. On Mac arm64, this will have GPU support, but on all other platforms it will be the lowest common denominator CPU build. After we fully transition to the new Go runners more tech-debt can be removed and we can stop building the "default" runner via make and rely on the builtin always. * build: Make target improvements Add a few new targets and help for building locally. This also adjusts the runner lookup to favor local builds, then runners relative to the executable, and finally payloads. * Support customized CPU flags for runners This implements a simplified custom CPU flags pattern for the runners. When built without overrides, the runner name contains the vector flag we check for (AVX) to ensure we don't try to run on unsupported systems and crash. If the user builds a customized set, we omit the naming scheme and don't check for compatibility. This avoids checking requirements at runtime, so that logic has been removed as well. This can be used to build GPU runners with no vector flags, or CPU/GPU runners with additional flags (e.g. AVX512) enabled. * Use relative paths If the user checks out the repo in a path that contains spaces, make gets really confused so use relative paths for everything in-repo to avoid breakage. * Remove payloads from main binary * install: clean up prior libraries This removes support for v0.3.6 and older versions (before the tar bundle) and ensures we clean up prior libraries before extracting the bundle(s). Without this change, runners and dependent libraries could leak when we update and lead to subtle runtime errors.
2024-12-10 09:47:19 -08:00
// Locate and log what runners are present at startup
var runnerNames []string
for v := range runners.GetAvailableServers() {
runnerNames = append(runnerNames, v)
}
build: Make target improvements (#7499) * llama: wire up builtin runner This adds a new entrypoint into the ollama CLI to run the cgo built runner. On Mac arm64, this will have GPU support, but on all other platforms it will be the lowest common denominator CPU build. After we fully transition to the new Go runners more tech-debt can be removed and we can stop building the "default" runner via make and rely on the builtin always. * build: Make target improvements Add a few new targets and help for building locally. This also adjusts the runner lookup to favor local builds, then runners relative to the executable, and finally payloads. * Support customized CPU flags for runners This implements a simplified custom CPU flags pattern for the runners. When built without overrides, the runner name contains the vector flag we check for (AVX) to ensure we don't try to run on unsupported systems and crash. If the user builds a customized set, we omit the naming scheme and don't check for compatibility. This avoids checking requirements at runtime, so that logic has been removed as well. This can be used to build GPU runners with no vector flags, or CPU/GPU runners with additional flags (e.g. AVX512) enabled. * Use relative paths If the user checks out the repo in a path that contains spaces, make gets really confused so use relative paths for everything in-repo to avoid breakage. * Remove payloads from main binary * install: clean up prior libraries This removes support for v0.3.6 and older versions (before the tar bundle) and ensures we clean up prior libraries before extracting the bundle(s). Without this change, runners and dependent libraries could leak when we update and lead to subtle runtime errors.
2024-12-10 09:47:19 -08:00
slog.Info("Dynamic LLM libraries", "runners", runnerNames)
slog.Debug("Override detection logic by setting OLLAMA_LLM_LIBRARY")
s.sched.Run(schedCtx)
// At startup we retrieve GPU information so we can get log messages before loading a model
// This will log warnings to the log in case we have problems with detected GPUs
gpus := discover.GetGPUInfo()
gpus.LogDetails()
err = srvr.Serve(ln)
// If server is closed from the signal handler, wait for the ctx to be done
// otherwise error out quickly
if !errors.Is(err, http.ErrServerClosed) {
return err
}
<-ctx.Done()
2024-05-16 16:25:38 -07:00
return nil
}
2023-07-06 10:40:11 -07:00
// waitForStream drains a progress channel and writes a single JSON
// response reflecting the terminal state: the "success" ProgressResponse
// on completion, the embedded error (with its status code) on failure, or
// a 500 if the channel yields an unexpected value or closes early. It is
// the non-streaming counterpart of streamResponse.
func waitForStream(c *gin.Context, ch chan any) {
	c.Header("Content-Type", "application/json")
	for resp := range ch {
		switch r := resp.(type) {
		case api.ProgressResponse:
			// Intermediate progress updates are discarded; only the
			// terminal "success" status is reported to the client.
			if r.Status == "success" {
				c.JSON(http.StatusOK, r)
				return
			}
		case gin.H:
			status, ok := r["status"].(int)
			if !ok {
				status = http.StatusInternalServerError
			}
			if errorMsg, ok := r["error"].(string); ok {
				c.JSON(status, gin.H{"error": errorMsg})
			} else {
				c.JSON(status, gin.H{"error": "unexpected error format in progress response"})
			}
			return
		default:
			c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected progress response"})
			return
		}
	}
	// Channel closed without ever signaling success or error.
	c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected end of progress response"})
}
func streamResponse(c *gin.Context, ch chan any) {
c.Header("Content-Type", "application/x-ndjson")
2023-07-11 11:54:22 -07:00
c.Stream(func(w io.Writer) bool {
val, ok := <-ch
if !ok {
return false
}
bts, err := json.Marshal(val)
if err != nil {
slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err))
2023-07-11 11:54:22 -07:00
return false
}
// Delineate chunks with new-line delimiter
2023-07-11 11:54:22 -07:00
bts = append(bts, '\n')
if _, err := w.Write(bts); err != nil {
slog.Info(fmt.Sprintf("streamResponse: w.Write failed with %s", err))
2023-07-11 11:54:22 -07:00
return false
}
return true
})
}
2023-12-05 14:57:33 -05:00
// PsHandler responds to /api/ps with the models currently loaded by the
// scheduler, including estimated total and VRAM sizes and the expiry
// time, ordered by the longest remaining time-to-live first.
//
// NOTE(review): s.sched.loaded is read here without a visible lock —
// confirm the scheduler guarantees safe concurrent access.
func (s *Server) PsHandler(c *gin.Context) {
	models := []api.ProcessModelResponse{}

	for _, v := range s.sched.loaded {
		model := v.model
		modelDetails := api.ModelDetails{
			Format:            model.Config.ModelFormat,
			Family:            model.Config.ModelFamily,
			Families:          model.Config.ModelFamilies,
			ParameterSize:     model.Config.ModelType,
			QuantizationLevel: model.Config.FileType,
		}

		mr := api.ProcessModelResponse{
			Model:     model.ShortName,
			Name:      model.ShortName,
			Size:      int64(v.estimatedTotal),
			SizeVRAM:  int64(v.estimatedVRAM),
			Digest:    model.Digest,
			Details:   modelDetails,
			ExpiresAt: v.expiresAt,
		}
		// The scheduler waits to set expiresAt, so if a model is loading it's
		// possible that it will be set to the unix epoch. For those cases, just
		// calculate the time w/ the sessionDuration instead.
		if v.expiresAt.IsZero() {
			mr.ExpiresAt = time.Now().Add(v.sessionDuration)
		}

		models = append(models, mr)
	}

	slices.SortStableFunc(models, func(i, j api.ProcessModelResponse) int {
		// longest duration remaining listed first
		return cmp.Compare(j.ExpiresAt.Unix(), i.ExpiresAt.Unix())
	})

	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
}
func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
2023-12-05 14:57:33 -05:00
var req api.ChatRequest
2024-06-17 10:38:55 -07:00
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
2023-12-05 14:57:33 -05:00
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
2024-06-17 10:38:55 -07:00
} else if err != nil {
2023-12-05 14:57:33 -05:00
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
2024-09-11 16:36:21 -07:00
// expire the runner
if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
model, err := GetModel(req.Model)
if err != nil {
switch {
case os.IsNotExist(err):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
case err.Error() == errtypes.InvalidModelNameErrMsg:
2024-09-11 16:36:21 -07:00
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
s.sched.expireRunner(model)
c.JSON(http.StatusOK, api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant"},
Done: true,
DoneReason: "unload",
})
return
}
2024-06-17 10:38:55 -07:00
caps := []Capability{CapabilityCompletion}
if len(req.Tools) > 0 {
2024-06-20 13:45:47 -07:00
caps = append(caps, CapabilityTools)
}
name := model.ParseName(req.Model)
if !name.IsValid() {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
name, err := getExistingName(name)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
2024-06-17 10:38:55 -07:00
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
2023-12-05 14:57:33 -05:00
return
2024-06-17 10:38:55 -07:00
} else if err != nil {
2024-06-20 11:00:08 -07:00
handleScheduleError(c, req.Model, err)
2023-12-05 14:57:33 -05:00
return
}
2024-01-31 17:39:38 -08:00
checkpointLoaded := time.Now()
2024-06-17 10:38:55 -07:00
if len(req.Messages) == 0 {
c.JSON(http.StatusOK, api.ChatResponse{
2024-05-09 13:30:14 -07:00
Model: req.Model,
2024-06-17 10:38:55 -07:00
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant"},
2024-05-09 13:30:14 -07:00
Done: true,
DoneReason: "load",
2024-06-17 10:38:55 -07:00
})
return
}
2024-06-19 14:14:28 -07:00
msgs := append(m.Messages, req.Messages...)
if req.Messages[0].Role != "system" && m.System != "" {
2024-06-19 14:14:28 -07:00
msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
}
2024-06-19 14:14:28 -07:00
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
2024-06-17 10:38:55 -07:00
if err != nil {
2024-11-27 13:40:57 -08:00
slog.Error("chat prompt error", "error", err)
2024-06-17 10:38:55 -07:00
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
2024-06-17 10:38:55 -07:00
slog.Debug("chat request", "images", len(images), "prompt", prompt)
2023-12-05 14:57:33 -05:00
ch := make(chan any)
go func() {
defer close(ch)
2024-11-27 13:40:57 -08:00
var sb strings.Builder
var toolCallIndex int = 0
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
2024-06-17 10:38:55 -07:00
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
2024-06-17 10:38:55 -07:00
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
2024-05-09 13:30:14 -07:00
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content},
Done: r.Done,
DoneReason: r.DoneReason,
2023-12-05 14:57:33 -05:00
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
}
if r.Done {
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
2024-11-27 13:40:57 -08:00
// TODO: tool call checking and filtering should be moved outside of this callback once streaming
// however this was a simple change for now without reworking streaming logic of this (and other)
// handlers
if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 {
ch <- res
return
}
// Streaming tool calls:
// If tools are recognized, use a flag to track the sending of a tool downstream
// This ensures that content is cleared from the message on the last chunk sent
sb.WriteString(r.Content)
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
res.Message.ToolCalls = toolCalls
for i := range toolCalls {
toolCalls[i].Function.Index = toolCallIndex
toolCallIndex++
}
2024-11-27 13:40:57 -08:00
res.Message.Content = ""
sb.Reset()
ch <- res
return
}
if r.Done {
// Send any remaining content if no tool calls were detected
if toolCallIndex == 0 {
2024-11-27 13:40:57 -08:00
res.Message.Content = sb.String()
}
ch <- res
}
2024-06-17 10:38:55 -07:00
}); err != nil {
2023-12-05 14:57:33 -05:00
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
2024-06-20 13:45:47 -07:00
var resp api.ChatResponse
2023-12-05 14:57:33 -05:00
var sb strings.Builder
2024-06-17 10:38:55 -07:00
for rr := range ch {
switch t := rr.(type) {
case api.ChatResponse:
2024-06-17 10:38:55 -07:00
sb.WriteString(t.Message.Content)
2024-06-20 13:45:47 -07:00
resp = t
case gin.H:
2024-06-17 10:38:55 -07:00
msg, ok := t["error"].(string)
if !ok {
msg = "unexpected error format in response"
}
2024-06-17 10:38:55 -07:00
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
return
default:
2024-06-17 10:38:55 -07:00
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
return
2023-12-05 14:57:33 -05:00
}
}
2024-06-20 13:45:47 -07:00
resp.Message.Content = sb.String()
if len(req.Tools) > 0 {
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
resp.Message.ToolCalls = toolCalls
resp.Message.Content = ""
}
2024-06-20 13:45:47 -07:00
}
c.JSON(http.StatusOK, resp)
2023-12-05 14:57:33 -05:00
return
}
streamResponse(c, ch)
}
2024-06-20 11:00:08 -07:00
func handleScheduleError(c *gin.Context, name string, err error) {
2024-06-17 10:38:55 -07:00
switch {
case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
2024-06-20 11:00:08 -07:00
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
2024-06-17 10:38:55 -07:00
case errors.Is(err, context.Canceled):
c.JSON(499, gin.H{"error": "request canceled"})
2024-06-17 10:38:55 -07:00
case errors.Is(err, ErrMaxQueue):
c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
2024-06-20 11:00:08 -07:00
case errors.Is(err, os.ErrNotExist):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
2024-06-17 10:38:55 -07:00
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
}