update message processing

This commit is contained in:
Michael Yang
2024-06-17 10:38:55 -07:00
parent 78fb33dd07
commit 269ed6e6a2
6 changed files with 685 additions and 720 deletions

View File

@@ -1,13 +1,13 @@
package server
import (
"bytes"
"cmp"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"net"
"net/http"
@@ -67,163 +67,140 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
return opts, nil
}
func isSupportedImageType(image []byte) bool {
contentType := http.DetectContentType(image)
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
return slices.Contains(allowedTypes, contentType)
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (*runnerRef, error) {
if name == "" {
return nil, errors.New("model is required")
}
model, err := GetModel(name)
if err != nil {
return nil, err
}
if err := model.CheckCapabilities(caps...); err != nil {
return nil, fmt.Errorf("%s %w", name, err)
}
opts, err := modelOptions(model, requestOpts)
if err != nil {
return nil, err
}
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
var runner *runnerRef
select {
case runner = <-runnerCh:
case err = <-errCh:
return nil, err
}
return runner, nil
}
func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// validate the request
switch {
case req.Model == "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
if req.Format != "" && req.Format != "json" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
return
case len(req.Format) > 0 && req.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return
case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
} else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
return
}
for _, img := range req.Images {
if !isSupportedImageType(img) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
return
}
}
model, err := GetModel(req.Model)
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
caps := []Capability{CapabilityCompletion}
r, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
return
} else if err != nil {
handleScheduleError(c, err)
return
}
if !model.Has(CapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)})
return
images := make([]llm.ImageData, len(req.Images))
for i := range req.Images {
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
}
opts, err := modelOptions(model, req.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
handleErrorResponse(c, err)
return
}
// an empty request loads the model
// note: for a short while template was used in lieu
// of `raw` mode so we need to check for it too
if req.Prompt == "" && req.Template == "" && req.System == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model,
Done: true,
DoneReason: "load",
})
return
}
tmpl, err := template.Parse(req.Template)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
checkpointLoaded := time.Now()
var prompt string
switch {
case req.Raw:
prompt = req.Prompt
case req.Prompt != "":
if req.Template == "" {
tmpl = model.Template
prompt := req.Prompt
if !req.Raw {
var msgs []api.Message
if req.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
} else if r.model.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: r.model.System})
}
if req.System == "" {
req.System = model.System
if req.Prompt != "" {
for _, i := range images {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
}
msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
}
slog.Debug("generate handler", "prompt", req.Prompt)
slog.Debug("generate handler", "template", req.Template)
slog.Debug("generate handler", "system", req.System)
var sb strings.Builder
for i := range req.Images {
fmt.Fprintf(&sb, "[img-%d] ", i)
}
sb.WriteString(req.Prompt)
p, err := Prompt(tmpl, req.System, sb.String(), "", true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
if len(msgs) == 0 {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: true,
DoneReason: "load",
})
return
}
sb.Reset()
tmpl := r.model.Template
if req.Template != "" {
tmpl, err = template.Parse(req.Template)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
var b bytes.Buffer
if req.Context != nil {
prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
s, err := r.llama.Detokenize(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sb.WriteString(prev)
b.WriteString(s)
}
sb.WriteString(p)
if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
prompt = sb.String()
prompt = b.String()
}
slog.Debug("generate handler", "prompt", prompt)
slog.Debug("generate request", "prompt", prompt, "images", images)
ch := make(chan any)
var generated strings.Builder
go func() {
defer close(ch)
fn := func(r llm.CompletionResponse) {
// Build up the full response
if _, err := generated.WriteString(r.Content); err != nil {
ch <- gin.H{"error": err.Error()}
return
}
resp := api.GenerateResponse{
if err := r.llama.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: *r.Options,
}, func(r llm.CompletionResponse) {
ch <- api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: r.Done,
Response: r.Content,
Done: r.Done,
DoneReason: r.DoneReason,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
@@ -232,77 +209,35 @@ func (s *Server) GenerateHandler(c *gin.Context) {
EvalDuration: r.EvalDuration,
},
}
if r.Done {
resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
p, err := Prompt(tmpl, req.System, req.Prompt, generated.String(), false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// TODO (jmorganca): encode() should not strip special tokens
tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
resp.Context = append(req.Context, tokens...)
}
}
ch <- resp
}
var images []llm.ImageData
for i := range req.Images {
images = append(images, llm.ImageData{
ID: i,
Data: req.Images[i],
})
}
// Start prediction
req := llm.CompletionRequest{
Prompt: prompt,
Format: req.Format,
Images: images,
Options: opts,
}
if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
}); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
// Accumulate responses into the final response
var final api.GenerateResponse
var r api.GenerateResponse
var sb strings.Builder
for resp := range ch {
switch r := resp.(type) {
for rr := range ch {
switch t := rr.(type) {
case api.GenerateResponse:
sb.WriteString(r.Response)
final = r
sb.WriteString(t.Response)
r = t
case gin.H:
if errorMsg, ok := r["error"].(string); ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
return
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
return
msg, ok := t["error"].(string)
if !ok {
msg = "unexpected error format in response"
}
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
return
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
return
}
}
final.Response = sb.String()
c.JSON(http.StatusOK, final)
r.Response = sb.String()
c.JSON(http.StatusOK, r)
return
}
@@ -311,44 +246,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
model, err := GetModel(req.Model)
r, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
opts, err := modelOptions(model, req.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
handleErrorResponse(c, err)
handleScheduleError(c, err)
return
}
@@ -358,17 +266,14 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
embedding, err := r.llama.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
}
resp := api.EmbeddingResponse{
Embedding: embedding,
}
c.JSON(http.StatusOK, resp)
c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding})
}
func (s *Server) PullModelHandler(c *gin.Context) {
@@ -649,9 +554,9 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
}
msgs := make([]api.Message, 0)
for _, msg := range m.Messages {
msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
msgs := make([]api.Message, len(m.Messages))
for i, msg := range m.Messages {
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
}
n := model.ParseName(req.Model)
@@ -1214,132 +1119,55 @@ func (s *Server) ProcessHandler(c *gin.Context) {
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
}
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
func chatPrompt(ctx context.Context, runner *runnerRef, template *template.Template, messages []api.Message, numCtx int) (string, error) {
encode := func(s string) ([]int, error) {
return runner.llama.Tokenize(ctx, s)
}
prompt, err := ChatPrompt(template, messages, numCtx, encode)
if err != nil {
return "", err
}
return prompt, nil
}
func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.ChatRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// validate the request
switch {
case req.Model == "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
caps := []Capability{CapabilityCompletion}
r, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
return
case len(req.Format) > 0 && req.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
} else if err != nil {
handleScheduleError(c, err)
return
}
model, err := GetModel(req.Model)
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if !model.Has(CapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)})
return
}
opts, err := modelOptions(model, req.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
handleErrorResponse(c, err)
return
}
checkpointLoaded := time.Now()
// if the first message is not a system message, then add the model's default system message
if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
req.Messages = append([]api.Message{
{
Role: "system",
Content: model.System,
},
}, req.Messages...)
}
prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// an empty request loads the model
if len(req.Messages) == 0 || prompt == "" {
resp := api.ChatResponse{
CreatedAt: time.Now().UTC(),
if len(req.Messages) == 0 {
c.JSON(http.StatusOK, api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant"},
Done: true,
DoneReason: "load",
Message: api.Message{Role: "assistant"},
}
c.JSON(http.StatusOK, resp)
})
return
}
// only send images that are in the prompt
var i int
var images []llm.ImageData
for _, m := range req.Messages {
for _, img := range m.Images {
if !isSupportedImageType(img) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
return
}
if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
images = append(images, llm.ImageData{Data: img, ID: i})
}
i += 1
}
prompt, images, err := chatPrompt(c.Request.Context(), r, req.Messages)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
slog.Debug("chat handler", "prompt", prompt, "images", len(images))
slog.Debug("chat request", "images", len(images), "prompt", prompt)
ch := make(chan any)
go func() {
defer close(ch)
fn := func(r llm.CompletionResponse) {
resp := api.ChatResponse{
if err := r.llama.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: *r.Options,
}, func(r llm.CompletionResponse) {
ch <- api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content},
@@ -1352,64 +1180,48 @@ func (s *Server) ChatHandler(c *gin.Context) {
EvalDuration: r.EvalDuration,
},
}
if r.Done {
resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
ch <- resp
}
if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Format: req.Format,
Images: images,
Options: opts,
}, fn); err != nil {
}); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
// Accumulate responses into the final response
var final api.ChatResponse
var r api.ChatResponse
var sb strings.Builder
for resp := range ch {
switch r := resp.(type) {
for rr := range ch {
switch t := rr.(type) {
case api.ChatResponse:
sb.WriteString(r.Message.Content)
final = r
sb.WriteString(t.Message.Content)
r = t
case gin.H:
if errorMsg, ok := r["error"].(string); ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
return
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
return
msg, ok := t["error"].(string)
if !ok {
msg = "unexpected error format in response"
}
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
return
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
return
}
}
final.Message = api.Message{Role: "assistant", Content: sb.String()}
c.JSON(http.StatusOK, final)
r.Message.Content = sb.String()
c.JSON(http.StatusOK, r)
return
}
streamResponse(c, ch)
}
func handleErrorResponse(c *gin.Context, err error) {
if errors.Is(err, context.Canceled) {
func handleScheduleError(c *gin.Context, err error) {
switch {
case errors.Is(err, context.Canceled):
c.JSON(499, gin.H{"error": "request canceled"})
return
}
if errors.Is(err, ErrMaxQueue) {
case errors.Is(err, ErrMaxQueue):
c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
return
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}