server: add logprobs and top_logprobs support to Ollama's API (#12899)

Adds logprobs support to Ollama's API, including Ollama's OpenAI-compatible
API. When the new 'logprobs' boolean parameter is set, Ollama returns the log
probability of each generated token. An integer 'top_logprobs' parameter
(0-20) can also be supplied; when set, the API additionally returns that many
of the most likely tokens, with their log probabilities, at each token
position.
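
A minimal client-side sketch using the Go API client, assuming the new request fields are exported on api.GenerateRequest as Logprobs and TopLogprobs (as the handler changes below suggest) and that the response carries a Logprobs slice; the model name is a placeholder:

    package main

    import (
        "context"
        "fmt"
        "log"

        "github.com/ollama/ollama/api"
    )

    func main() {
        client, err := api.ClientFromEnvironment()
        if err != nil {
            log.Fatal(err)
        }

        req := &api.GenerateRequest{
            Model:       "llama3.2", // placeholder model name
            Prompt:      "Why is the sky blue?",
            Logprobs:    true, // mirrors req.Logprobs in the handlers below
            TopLogprobs: 5,    // must be between 0 and 20 per the new validation
        }

        // With streaming (the default), each chunk carries the logprobs for
        // the tokens generated in that chunk.
        err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
            fmt.Printf("%q -> %d logprob entries\n", resp.Response, len(resp.Logprobs))
            return nil
        })
        if err != nil {
            log.Fatal(err)
        }
    }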

Co-authored-by: Baptiste Jamin <baptiste@crisp.chat>
Author: Baptiste Jamin
Date: 2025-11-11 17:49:50 +01:00
Committed by: GitHub
Parent: 6df4208836
Commit: 59241c5bee
13 changed files with 1367 additions and 47 deletions


@@ -183,6 +183,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
return
}
name := model.ParseName(req.Model)
if !name.IsValid() {
// Ideally this is "invalid model name" but we're keeping with
@@ -212,6 +217,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
return
}
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
origModel := req.Model
@@ -502,12 +512,14 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var sb strings.Builder
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-	Prompt:   prompt,
-	Images:   images,
-	Format:   req.Format,
-	Options:  opts,
-	Shift:    req.Shift == nil || *req.Shift,
-	Truncate: req.Truncate == nil || *req.Truncate,
+	Prompt:      prompt,
+	Images:      images,
+	Format:      req.Format,
+	Options:     opts,
+	Shift:       req.Shift == nil || *req.Shift,
+	Truncate:    req.Truncate == nil || *req.Truncate,
+	Logprobs:    req.Logprobs,
+	TopLogprobs: req.TopLogprobs,
}, func(cr llm.CompletionResponse) {
res := api.GenerateResponse{
Model: req.Model,
@@ -520,6 +532,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
EvalCount: cr.EvalCount,
EvalDuration: cr.EvalDuration,
},
Logprobs: toAPILogprobs(cr.Logprobs),
}
if builtinParser != nil {
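
toAPILogprobs itself is not included in the hunks shown here. Conceptually it converts the per-token logprob records returned by the runner (llm.CompletionResponse.Logprobs) into the api.Logprob values exposed to clients; a purely hypothetical sketch, with field names invented for illustration:

    // Hypothetical sketch only: the runner-side logprob type and both structs'
    // fields are not visible in this diff, so the names below are illustrative.
    func toAPILogprobs(in []llm.Logprob) []api.Logprob {
        if len(in) == 0 {
            return nil
        }
        out := make([]api.Logprob, 0, len(in))
        for _, lp := range in {
            out = append(out, api.Logprob{
                Token:   lp.Token,   // the generated token (assumed field)
                Logprob: lp.Logprob, // its log probability (assumed field)
            })
        }
        return out
    }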
@@ -580,6 +593,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
if req.Stream != nil && !*req.Stream {
var r api.GenerateResponse
var allLogprobs []api.Logprob
var sbThinking strings.Builder
var sbContent strings.Builder
for rr := range ch {
@@ -588,6 +602,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
sbThinking.WriteString(t.Thinking)
sbContent.WriteString(t.Response)
r = t
// Accumulate logprobs from all chunks for non-streaming response
if len(t.Logprobs) > 0 {
allLogprobs = append(allLogprobs, t.Logprobs...)
}
case gin.H:
msg, ok := t["error"].(string)
if !ok {
@@ -609,6 +627,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
r.Thinking = sbThinking.String()
r.Response = sbContent.String()
r.Logprobs = allLogprobs
c.JSON(http.StatusOK, r)
return
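
With streaming disabled, the loop above concatenates the logprobs from every internal chunk into the single final response. A rough sketch of exercising that path over plain HTTP, decoding the logprob entries loosely since their exact JSON shape is not shown in this diff (endpoint and parameter names follow the commit description):

    package main

    import (
        "bytes"
        "encoding/json"
        "fmt"
        "log"
        "net/http"
    )

    func main() {
        body, _ := json.Marshal(map[string]any{
            "model":        "llama3.2", // placeholder model name
            "prompt":       "Why is the sky blue?",
            "stream":       false,
            "logprobs":     true, // assumed JSON tags matching the new request fields
            "top_logprobs": 5,
        })

        resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
        if err != nil {
            log.Fatal(err)
        }
        defer resp.Body.Close()

        var out struct {
            Response string            `json:"response"`
            Logprobs []json.RawMessage `json:"logprobs"` // one entry per generated token (assumed tag)
        }
        if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
            log.Fatal(err)
        }
        fmt.Printf("%d tokens with logprobs\n", len(out.Logprobs))
    }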
@@ -1834,6 +1853,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
return
}
name := model.ParseName(req.Model)
if !name.IsValid() {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
@@ -1859,6 +1883,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
if req.TopLogprobs < 0 || req.TopLogprobs > 20 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"})
return
}
// expire the runner
if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
s.sched.expireRunner(m)
@@ -2104,12 +2133,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
// sets up new context given parent context per request
ctx, cancel := context.WithCancel(c.Request.Context())
err := r.Completion(ctx, llm.CompletionRequest{
-	Prompt:   prompt,
-	Images:   images,
-	Format:   currentFormat,
-	Options:  opts,
-	Shift:    req.Shift == nil || *req.Shift,
-	Truncate: truncate,
+	Prompt:      prompt,
+	Images:      images,
+	Format:      currentFormat,
+	Options:     opts,
+	Shift:       req.Shift == nil || *req.Shift,
+	Truncate:    truncate,
+	Logprobs:    req.Logprobs,
+	TopLogprobs: req.TopLogprobs,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model,
@@ -2122,7 +2153,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
Logprobs: toAPILogprobs(r.Logprobs),
}
if r.Done {
res.DoneReason = r.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
@@ -2251,6 +2284,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
if req.Stream != nil && !*req.Stream {
var resp api.ChatResponse
var toolCalls []api.ToolCall
var allLogprobs []api.Logprob
var sbThinking strings.Builder
var sbContent strings.Builder
for rr := range ch {
@@ -2262,6 +2296,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
if len(req.Tools) > 0 {
toolCalls = append(toolCalls, t.Message.ToolCalls...)
}
// Accumulate logprobs from all chunks for non-streaming response
if len(t.Logprobs) > 0 {
allLogprobs = append(allLogprobs, t.Logprobs...)
}
case gin.H:
msg, ok := t["error"].(string)
if !ok {
@@ -2283,6 +2321,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
resp.Message.Content = sbContent.String()
resp.Message.Thinking = sbThinking.String()
resp.Logprobs = allLogprobs
if len(toolCalls) > 0 {
resp.Message.ToolCalls = toolCalls