server: add logprobs and top_logprobs support to Ollama's API (#12899)
Adds logprobs support to Ollama's API, including Ollama's OpenAI-compatible API. When the new 'logprobs' boolean parameter is set, Ollama returns the log probability of each generated token. An integer 'top_logprobs' parameter, between 0 and 20, can also be specified; when set, the API additionally returns that many of the most likely tokens, with their log probabilities, at each token position.

Co-authored-by: Baptiste Jamin <baptiste@crisp.chat>
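For anyone trying the change out, here is a minimal sketch of a client call against a local server. The endpoint and the two new field names come from this change; the model name and the raw-print response handling are placeholder choices:

// Sketch: request log probabilities from a local Ollama server via the
// native /api/generate endpoint. Assumes a server on the default port
// and that the named model is already pulled.
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

func main() {
	body, _ := json.Marshal(map[string]any{
		"model":        "llama3.2", // placeholder: any locally available model
		"prompt":       "Why is the sky blue?",
		"stream":       false,
		"logprobs":     true, // new: return a logprob per generated token
		"top_logprobs": 5,    // new: also return the 5 most likely alternatives per position
	})

	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out)) // the response now carries a "logprobs" field alongside "response"
}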
@@ -183,6 +183,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
+	if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
+		return
+	}
+
 	name := model.ParseName(req.Model)
 	if !name.IsValid() {
 		// Ideally this is "invalid model name" but we're keeping with
@@ -212,6 +217,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
+	if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
+		return
+	}
+
 	if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
 		origModel := req.Model
 
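The same bounds check is applied at both entry points, so an out-of-range value fails fast, before any model is resolved or loaded. A self-contained sketch of the observable behavior, using a stand-in gin handler rather than the real server wiring:

// Sketch: the guard's behavior reproduced against a stand-in handler.
// Out-of-range top_logprobs values are rejected with HTTP 400.
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"

	"github.com/gin-gonic/gin"
)

func main() {
	gin.SetMode(gin.ReleaseMode)
	r := gin.New()
	r.POST("/api/generate", func(c *gin.Context) {
		var req struct {
			TopLogprobs int `json:"top_logprobs"`
		}
		if err := c.BindJSON(&req); err != nil {
			return
		}
		// Same bounds check as the handler above.
		if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
			return
		}
		c.JSON(http.StatusOK, gin.H{"ok": true})
	})

	for _, n := range []int{-1, 0, 20, 21} {
		w := httptest.NewRecorder()
		body := strings.NewReader(fmt.Sprintf(`{"top_logprobs": %d}`, n))
		req := httptest.NewRequest("POST", "/api/generate", body)
		r.ServeHTTP(w, req)
		fmt.Printf("top_logprobs=%d -> %d %s\n", n, w.Code, w.Body.String())
	}
}

Running this prints 400 responses for -1 and 21, and 200 for the in-range values 0 and 20.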
@@ -502,12 +512,14 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		var sb strings.Builder
 		defer close(ch)
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:   prompt,
-			Images:   images,
-			Format:   req.Format,
-			Options:  opts,
-			Shift:    req.Shift == nil || *req.Shift,
-			Truncate: req.Truncate == nil || *req.Truncate,
+			Prompt:      prompt,
+			Images:      images,
+			Format:      req.Format,
+			Options:     opts,
+			Shift:       req.Shift == nil || *req.Shift,
+			Truncate:    req.Truncate == nil || *req.Truncate,
+			Logprobs:    req.Logprobs,
+			TopLogprobs: req.TopLogprobs,
 		}, func(cr llm.CompletionResponse) {
 			res := api.GenerateResponse{
 				Model: req.Model,
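Most lines in the hunk above are unchanged in substance; they appear as removed-and-readded apparently only because gofmt re-aligns the struct literal once the longer TopLogprobs key is added. The new options themselves are plain pass-throughs into the completion request; a trimmed sketch with illustrative stand-in types (not the real api/llm definitions):

// Sketch: how the two new options ride along on the completion request.
// Both struct shapes are stand-ins for api.GenerateRequest and
// llm.CompletionRequest, trimmed to the relevant fields.
package main

import "fmt"

type GenerateRequest struct {
	Model       string
	Prompt      string
	Logprobs    bool // request a logprob for each generated token
	TopLogprobs int  // also request the N most likely alternatives (0-20)
}

type CompletionRequest struct {
	Prompt      string
	Logprobs    bool
	TopLogprobs int
}

func main() {
	req := GenerateRequest{Model: "m", Prompt: "hi", Logprobs: true, TopLogprobs: 5}
	// The handler forwards the options unchanged to the runner.
	cr := CompletionRequest{
		Prompt:      req.Prompt,
		Logprobs:    req.Logprobs,
		TopLogprobs: req.TopLogprobs,
	}
	fmt.Printf("%+v\n", cr)
}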
@@ -520,6 +532,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 					EvalCount:    cr.EvalCount,
 					EvalDuration: cr.EvalDuration,
 				},
+				Logprobs: toAPILogprobs(cr.Logprobs),
 			}
 
 			if builtinParser != nil {
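toAPILogprobs converts the runner's per-token records into the API's wire types. The struct shapes below are assumptions for illustration only; the real definitions live in the llm and api packages:

// Sketch: the kind of conversion a helper like toAPILogprobs performs.
// All struct shapes here are assumed, not the real package types.
package main

import "fmt"

// llmLogprob stands in for the runner's per-token record.
type llmLogprob struct {
	Token       string
	Logprob     float64
	TopLogprobs []llmTokenLogprob
}

type llmTokenLogprob struct {
	Token   string
	Logprob float64
}

// apiLogprob stands in for api.Logprob, the JSON-facing type.
type apiLogprob struct {
	Token       string            `json:"token"`
	Logprob     float64           `json:"logprob"`
	TopLogprobs []apiTokenLogprob `json:"top_logprobs,omitempty"`
}

type apiTokenLogprob struct {
	Token   string  `json:"token"`
	Logprob float64 `json:"logprob"`
}

func toAPILogprobs(in []llmLogprob) []apiLogprob {
	if len(in) == 0 {
		return nil // keep the field absent when logprobs were not requested
	}
	out := make([]apiLogprob, 0, len(in))
	for _, lp := range in {
		a := apiLogprob{Token: lp.Token, Logprob: lp.Logprob}
		for _, t := range lp.TopLogprobs {
			a.TopLogprobs = append(a.TopLogprobs, apiTokenLogprob{Token: t.Token, Logprob: t.Logprob})
		}
		out = append(out, a)
	}
	return out
}

func main() {
	fmt.Printf("%+v\n", toAPILogprobs([]llmLogprob{{Token: "sky", Logprob: -0.12}}))
}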
@@ -580,6 +593,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 	if req.Stream != nil && !*req.Stream {
 		var r api.GenerateResponse
+		var allLogprobs []api.Logprob
 		var sbThinking strings.Builder
 		var sbContent strings.Builder
 		for rr := range ch {
@@ -588,6 +602,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 				sbThinking.WriteString(t.Thinking)
 				sbContent.WriteString(t.Response)
 				r = t
+				// Accumulate logprobs from all chunks for non-streaming response
+				if len(t.Logprobs) > 0 {
+					allLogprobs = append(allLogprobs, t.Logprobs...)
+				}
 			case gin.H:
 				msg, ok := t["error"].(string)
 				if !ok {
@@ -609,6 +627,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 		r.Thinking = sbThinking.String()
 		r.Response = sbContent.String()
+		r.Logprobs = allLogprobs
 
 		c.JSON(http.StatusOK, r)
 		return
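For non-streaming requests the handler drains the channel, concatenating text while appending each chunk's logprobs, so the final response carries one flat slice. The pattern in isolation, with trimmed stand-in types:

// Sketch: the non-streaming accumulation pattern used above. Each chunk
// carries its own slice of per-token logprobs; the final response gets
// the concatenation across all chunks.
package main

import (
	"fmt"
	"strings"
)

type chunk struct {
	Response string
	Logprobs []string // stand-in for []api.Logprob
}

func main() {
	ch := make(chan chunk, 2)
	ch <- chunk{Response: "Why ", Logprobs: []string{"Why"}}
	ch <- chunk{Response: "is the sky blue?", Logprobs: []string{"is", "the", "sky", "blue", "?"}}
	close(ch)

	var sb strings.Builder
	var allLogprobs []string
	for c := range ch {
		sb.WriteString(c.Response)
		// Accumulate logprobs from all chunks, mirroring the handler.
		if len(c.Logprobs) > 0 {
			allLogprobs = append(allLogprobs, c.Logprobs...)
		}
	}
	fmt.Println(sb.String(), allLogprobs)
}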
@@ -1834,6 +1853,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}
 
+	if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
+		return
+	}
+
 	name := model.ParseName(req.Model)
 	if !name.IsValid() {
 		c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
@@ -1859,6 +1883,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}
 
+	if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
+		return
+	}
+
 	// expire the runner
 	if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
 		s.sched.expireRunner(m)
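Per the commit message, the options are also plumbed through the OpenAI-compatible layer. A sketch against /v1/chat/completions, assuming it mirrors OpenAI's request fields and the choices[].logprobs.content[] response shape (the model name is a placeholder):

// Sketch: the same options via Ollama's OpenAI-compatible chat endpoint.
// The response struct below assumes the standard OpenAI logprobs shape.
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

type chatResponse struct {
	Choices []struct {
		Logprobs struct {
			Content []struct {
				Token       string  `json:"token"`
				Logprob     float64 `json:"logprob"`
				TopLogprobs []struct {
					Token   string  `json:"token"`
					Logprob float64 `json:"logprob"`
				} `json:"top_logprobs"`
			} `json:"content"`
		} `json:"logprobs"`
	} `json:"choices"`
}

func main() {
	body, _ := json.Marshal(map[string]any{
		"model":        "llama3.2", // placeholder: any locally available model
		"messages":     []map[string]string{{"role": "user", "content": "Say hello."}},
		"logprobs":     true,
		"top_logprobs": 5,
	})

	resp, err := http.Post("http://localhost:11434/v1/chat/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out chatResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	if len(out.Choices) == 0 {
		panic("no choices in response")
	}
	for _, tok := range out.Choices[0].Logprobs.Content {
		fmt.Printf("%-12q logprob=%.4f (%d alternatives)\n", tok.Token, tok.Logprob, len(tok.TopLogprobs))
	}
}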
@@ -2104,12 +2133,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		// sets up new context given parent context per request
 		ctx, cancel := context.WithCancel(c.Request.Context())
 		err := r.Completion(ctx, llm.CompletionRequest{
-			Prompt:   prompt,
-			Images:   images,
-			Format:   currentFormat,
-			Options:  opts,
-			Shift:    req.Shift == nil || *req.Shift,
-			Truncate: truncate,
+			Prompt:      prompt,
+			Images:      images,
+			Format:      currentFormat,
+			Options:     opts,
+			Shift:       req.Shift == nil || *req.Shift,
+			Truncate:    truncate,
+			Logprobs:    req.Logprobs,
+			TopLogprobs: req.TopLogprobs,
 		}, func(r llm.CompletionResponse) {
 			res := api.ChatResponse{
 				Model: req.Model,
@@ -2122,7 +2153,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
 					EvalCount:    r.EvalCount,
 					EvalDuration: r.EvalDuration,
 				},
+				Logprobs: toAPILogprobs(r.Logprobs),
 			}
+
 			if r.Done {
 				res.DoneReason = r.DoneReason.String()
 				res.TotalDuration = time.Since(checkpointStart)
@@ -2251,6 +2284,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	if req.Stream != nil && !*req.Stream {
 		var resp api.ChatResponse
 		var toolCalls []api.ToolCall
+		var allLogprobs []api.Logprob
 		var sbThinking strings.Builder
 		var sbContent strings.Builder
 		for rr := range ch {
@@ -2262,6 +2296,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
 				if len(req.Tools) > 0 {
 					toolCalls = append(toolCalls, t.Message.ToolCalls...)
 				}
+				// Accumulate logprobs from all chunks for non-streaming response
+				if len(t.Logprobs) > 0 {
+					allLogprobs = append(allLogprobs, t.Logprobs...)
+				}
 			case gin.H:
 				msg, ok := t["error"].(string)
 				if !ok {
@@ -2283,6 +2321,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 
 		resp.Message.Content = sbContent.String()
 		resp.Message.Thinking = sbThinking.String()
+		resp.Logprobs = allLogprobs
 
 		if len(toolCalls) > 0 {
 			resp.Message.ToolCalls = toolCalls
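One thing the returned values enable downstream is a cheap sequence-level confidence measure: average the per-token logprobs and exponentiate to get perplexity. A sketch with made-up numbers:

// Sketch: a downstream use of returned logprobs. Perplexity is
// exp(-mean(logprob)); the values here are made up for illustration.
package main

import (
	"fmt"
	"math"
)

func main() {
	logprobs := []float64{-0.11, -0.42, -0.07, -1.35, -0.25}

	var sum float64
	for _, lp := range logprobs {
		sum += lp
	}
	mean := sum / float64(len(logprobs))

	fmt.Printf("mean logprob: %.3f\n", mean)
	fmt.Printf("perplexity:   %.3f\n", math.Exp(-mean)) // lower is more confident
}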