server: add logprobs and top_logprobs support to Ollama's API (#12899)

Adds logprobs support to Ollama's API including support for Ollama's
OpenAI-compatible API. By specifying the new 'logprobs' boolean parameter
in the API, Ollama will return the log probabilities for each token generated.
'top_logprobs', an integer value can also be specified up to the value 20.
When specified, the API will also provide the number of most likely tokens to
return at each token position

Co-authored-by: Baptiste Jamin <baptiste@crisp.chat>
This commit is contained in:
Baptiste Jamin
2025-11-11 17:49:50 +01:00
committed by GitHub
parent 6df4208836
commit 59241c5bee
13 changed files with 1367 additions and 47 deletions

View File

@@ -381,3 +381,174 @@ func TestAPIShowModel(t *testing.T) {
t.Errorf("%s missing modified_at: %#v", modelName, resp)
}
}
func TestAPIGenerateLogprobs(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, smol); err != nil {
t.Fatalf("pull failed %s", err)
}
enableLogprobs := true
noStream := false
tests := []struct {
name string
logprobs *bool
topLogprobs int
expectCount int
}{
{
name: "no_logprobs",
logprobs: nil,
topLogprobs: 0,
expectCount: 0,
},
{
name: "logprobs_only",
logprobs: &enableLogprobs,
topLogprobs: 0,
expectCount: 1,
},
{
name: "logprobs_with_top_5",
logprobs: &enableLogprobs,
topLogprobs: 5,
expectCount: 1,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
req := api.GenerateRequest{
Model: smol,
Prompt: "Why is the sky blue?",
Stream: &noStream,
Logprobs: test.logprobs != nil && *test.logprobs,
TopLogprobs: test.topLogprobs,
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
"num_predict": 10,
},
}
var response api.GenerateResponse
err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error {
if resp.Done {
response = resp
}
return nil
})
if err != nil {
t.Fatalf("generate failed: %s", err)
}
// Check logprobs based on expectation
if test.expectCount == 0 {
if len(response.Logprobs) > 0 {
t.Errorf("expected no logprobs but got %d", len(response.Logprobs))
}
} else {
if len(response.Logprobs) == 0 {
t.Errorf("expected logprobs but got none")
}
// Validate each logprob entry
for i, lp := range response.Logprobs {
if lp.Token == "" {
t.Errorf("logprob[%d] has empty token", i)
}
if lp.Logprob > 0 {
t.Errorf("logprob[%d] has positive logprob %f (should be <= 0)", i, lp.Logprob)
}
// Check top_logprobs if requested
if test.topLogprobs > 0 {
if len(lp.TopLogprobs) == 0 {
t.Errorf("logprob[%d] expected top_logprobs but got none", i)
}
if len(lp.TopLogprobs) > test.topLogprobs {
t.Errorf("logprob[%d] has %d top_logprobs, expected max %d", i, len(lp.TopLogprobs), test.topLogprobs)
}
// Verify top_logprobs are sorted by probability (descending)
for j := 1; j < len(lp.TopLogprobs); j++ {
if lp.TopLogprobs[j-1].Logprob < lp.TopLogprobs[j].Logprob {
t.Errorf("logprob[%d].top_logprobs not sorted: %f < %f", i, lp.TopLogprobs[j-1].Logprob, lp.TopLogprobs[j].Logprob)
}
}
} else if len(lp.TopLogprobs) > 0 {
t.Errorf("logprob[%d] has top_logprobs but none were requested", i)
}
}
}
})
}
}
func TestAPIChatLogprobs(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, smol); err != nil {
t.Fatalf("pull failed %s", err)
}
enableLogprobs := true
noStream := false
req := api.ChatRequest{
Model: smol,
Messages: []api.Message{
{Role: "user", Content: "Say hello in one word"},
},
Stream: &noStream,
Logprobs: enableLogprobs,
TopLogprobs: 3,
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
"num_predict": 5,
},
}
var response api.ChatResponse
err := client.Chat(ctx, &req, func(resp api.ChatResponse) error {
if resp.Done {
response = resp
}
return nil
})
if err != nil {
t.Fatalf("chat failed: %s", err)
}
if len(response.Logprobs) == 0 {
t.Fatal("expected logprobs in response but got none")
}
t.Logf("received %d logprobs for chat response", len(response.Logprobs))
for i, lp := range response.Logprobs {
if lp.Token == "" {
t.Errorf("logprob[%d] has empty token", i)
}
if lp.Logprob > 0 {
t.Errorf("logprob[%d] has positive logprob %f", i, lp.Logprob)
}
if len(lp.TopLogprobs) == 0 {
t.Errorf("logprob[%d] expected top_logprobs but got none", i)
}
if len(lp.TopLogprobs) > 3 {
t.Errorf("logprob[%d] has %d top_logprobs, expected max 3", i, len(lp.TopLogprobs))
}
}
}