This commit is contained in:
Michael Yang
2024-06-20 13:45:47 -07:00
parent 9e35d9bbee
commit d02bbebb11
7 changed files with 263 additions and 53 deletions

View File

@@ -265,6 +265,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
r.Response = sb.String()
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
r.ToolCalls = toolCalls
r.Response = ""
}
c.JSON(http.StatusOK, r)
return
}
@@ -1279,6 +1284,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
caps := []Capability{CapabilityCompletion}
if req.Tools != nil {
caps = append(caps, CapabilityTools)
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
@@ -1305,7 +1314,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...)
}
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages, req.Tools)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -1348,13 +1357,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
}()
if req.Stream != nil && !*req.Stream {
var r api.ChatResponse
var resp api.ChatResponse
var sb strings.Builder
for rr := range ch {
switch t := rr.(type) {
case api.ChatResponse:
sb.WriteString(t.Message.Content)
r = t
resp = t
case gin.H:
msg, ok := t["error"].(string)
if !ok {
@@ -1369,8 +1378,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
}
r.Message.Content = sb.String()
c.JSON(http.StatusOK, r)
resp.Message.Content = sb.String()
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
resp.Message.ToolCalls = toolCalls
resp.Message.Content = ""
}
c.JSON(http.StatusOK, resp)
return
}