This commit is contained in:
Michael Yang
2024-06-20 11:00:08 -07:00
parent 269ed6e6a2
commit 2c3fe1fd97
5 changed files with 224 additions and 113 deletions

View File

@@ -54,6 +54,8 @@ func init() {
gin.SetMode(mode)
}
var errRequired = errors.New("is required")
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
@@ -69,7 +71,7 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (*runnerRef, error) {
if name == "" {
return nil, errors.New("model is required")
return nil, fmt.Errorf("model %w", errRequired)
}
model, err := GetModel(name)
@@ -121,7 +123,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
return
} else if err != nil {
handleScheduleError(c, err)
handleScheduleError(c, req.Model, err)
return
}
if req.Prompt == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: true,
DoneReason: "load",
})
return
}
@@ -139,23 +151,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
msgs = append(msgs, api.Message{Role: "system", Content: r.model.System})
}
if req.Prompt != "" {
for _, i := range images {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
}
msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
for _, i := range images {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
}
if len(msgs) == 0 {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: true,
DoneReason: "load",
})
return
}
msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
tmpl := r.model.Template
if req.Template != "" {
@@ -256,7 +256,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
r, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, err)
handleScheduleError(c, req.Model, err)
return
}
@@ -1135,7 +1135,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
return
} else if err != nil {
handleScheduleError(c, err)
handleScheduleError(c, req.Model, err)
return
}
@@ -1150,7 +1150,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
prompt, images, err := chatPrompt(c.Request.Context(), r, req.Messages)
prompt, images, err := chatPrompt(c.Request.Context(), r.model, r.llama.Tokenize, r.Options, req.Messages)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -1215,12 +1215,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
streamResponse(c, ch)
}
func handleScheduleError(c *gin.Context, err error) {
func handleScheduleError(c *gin.Context, name string, err error) {
switch {
case errors.Is(err, errRequired):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
case errors.Is(err, context.Canceled):
c.JSON(499, gin.H{"error": "request canceled"})
case errors.Is(err, ErrMaxQueue):
c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
case errors.Is(err, os.ErrNotExist):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}