Mirror of https://github.com/ollama/ollama.git, synced 2025-12-08 09:52:38 +01:00
server: add logprobs and top_logprobs support to Ollama's API (#12899)
Adds logprobs support to Ollama's API, including Ollama's OpenAI-compatible API. When the new 'logprobs' boolean parameter is set, the API returns the log probability of each generated token. An integer 'top_logprobs' parameter (up to 20) can also be specified; when set, the API additionally returns that many of the most likely tokens, with their log probabilities, at each token position.

Co-authored-by: Baptiste Jamin <baptiste@crisp.chat>
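For reference, a minimal sketch of requesting logprobs through the Go API client, based only on the request and response fields exercised by the integration tests added in this commit (GenerateRequest.Logprobs, GenerateRequest.TopLogprobs, and the per-token Token, Logprob, and TopLogprobs fields). The model name "llama3.2" is a placeholder, not part of this change.

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	stream := false
	req := &api.GenerateRequest{
		Model:       "llama3.2", // placeholder model name
		Prompt:      "Why is the sky blue?",
		Stream:      &stream,
		Logprobs:    true, // return the log probability of each generated token
		TopLogprobs: 5,    // also return the 5 most likely tokens per position (1-20)
	}

	err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
		// With streaming disabled, the callback fires once with the full response.
		for _, lp := range resp.Logprobs {
			fmt.Printf("%q\tlogprob=%.4f\talternatives=%d\n", lp.Token, lp.Logprob, len(lp.TopLogprobs))
		}
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}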
@@ -381,3 +381,174 @@ func TestAPIShowModel(t *testing.T) {
		t.Errorf("%s missing modified_at: %#v", modelName, resp)
	}
}

func TestAPIGenerateLogprobs(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	if err := PullIfMissing(ctx, client, smol); err != nil {
		t.Fatalf("pull failed %s", err)
	}

	enableLogprobs := true
	noStream := false

	tests := []struct {
		name        string
		logprobs    *bool
		topLogprobs int
		expectCount int
	}{
		{
			name:        "no_logprobs",
			logprobs:    nil,
			topLogprobs: 0,
			expectCount: 0,
		},
		{
			name:        "logprobs_only",
			logprobs:    &enableLogprobs,
			topLogprobs: 0,
			expectCount: 1,
		},
		{
			name:        "logprobs_with_top_5",
			logprobs:    &enableLogprobs,
			topLogprobs: 5,
			expectCount: 1,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			req := api.GenerateRequest{
				Model:       smol,
				Prompt:      "Why is the sky blue?",
				Stream:      &noStream,
				Logprobs:    test.logprobs != nil && *test.logprobs,
				TopLogprobs: test.topLogprobs,
				Options: map[string]interface{}{
					"temperature": 0,
					"seed":        123,
					"num_predict": 10,
				},
			}

			var response api.GenerateResponse
			err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error {
				if resp.Done {
					response = resp
				}
				return nil
			})
			if err != nil {
				t.Fatalf("generate failed: %s", err)
			}

			// Check logprobs based on expectation
			if test.expectCount == 0 {
				if len(response.Logprobs) > 0 {
					t.Errorf("expected no logprobs but got %d", len(response.Logprobs))
				}
			} else {
				if len(response.Logprobs) == 0 {
					t.Errorf("expected logprobs but got none")
				}

				// Validate each logprob entry
				for i, lp := range response.Logprobs {
					if lp.Token == "" {
						t.Errorf("logprob[%d] has empty token", i)
					}
					if lp.Logprob > 0 {
						t.Errorf("logprob[%d] has positive logprob %f (should be <= 0)", i, lp.Logprob)
					}

					// Check top_logprobs if requested
					if test.topLogprobs > 0 {
						if len(lp.TopLogprobs) == 0 {
							t.Errorf("logprob[%d] expected top_logprobs but got none", i)
						}
						if len(lp.TopLogprobs) > test.topLogprobs {
							t.Errorf("logprob[%d] has %d top_logprobs, expected max %d", i, len(lp.TopLogprobs), test.topLogprobs)
						}

						// Verify top_logprobs are sorted by probability (descending)
						for j := 1; j < len(lp.TopLogprobs); j++ {
							if lp.TopLogprobs[j-1].Logprob < lp.TopLogprobs[j].Logprob {
								t.Errorf("logprob[%d].top_logprobs not sorted: %f < %f", i, lp.TopLogprobs[j-1].Logprob, lp.TopLogprobs[j].Logprob)
							}
						}
					} else if len(lp.TopLogprobs) > 0 {
						t.Errorf("logprob[%d] has top_logprobs but none were requested", i)
					}
				}
			}
		})
	}
}

func TestAPIChatLogprobs(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	if err := PullIfMissing(ctx, client, smol); err != nil {
		t.Fatalf("pull failed %s", err)
	}

	enableLogprobs := true
	noStream := false

	req := api.ChatRequest{
		Model: smol,
		Messages: []api.Message{
			{Role: "user", Content: "Say hello in one word"},
		},
		Stream:      &noStream,
		Logprobs:    enableLogprobs,
		TopLogprobs: 3,
		Options: map[string]interface{}{
			"temperature": 0,
			"seed":        123,
			"num_predict": 5,
		},
	}

	var response api.ChatResponse
	err := client.Chat(ctx, &req, func(resp api.ChatResponse) error {
		if resp.Done {
			response = resp
		}
		return nil
	})
	if err != nil {
		t.Fatalf("chat failed: %s", err)
	}

	if len(response.Logprobs) == 0 {
		t.Fatal("expected logprobs in response but got none")
	}

	t.Logf("received %d logprobs for chat response", len(response.Logprobs))

	for i, lp := range response.Logprobs {
		if lp.Token == "" {
			t.Errorf("logprob[%d] has empty token", i)
		}
		if lp.Logprob > 0 {
			t.Errorf("logprob[%d] has positive logprob %f", i, lp.Logprob)
		}
		if len(lp.TopLogprobs) == 0 {
			t.Errorf("logprob[%d] expected top_logprobs but got none", i)
		}
		if len(lp.TopLogprobs) > 3 {
			t.Errorf("logprob[%d] has %d top_logprobs, expected max 3", i, len(lp.TopLogprobs))
		}
	}
}
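Since the commit message also covers the OpenAI-compatible API, below is a minimal sketch of exercising the same feature over /v1/chat/completions. It assumes Ollama's endpoint mirrors the standard OpenAI logprobs response shape (choices[].logprobs.content[] with token, logprob, and top_logprobs); the base URL and model name are placeholders.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

// chatLogprobs models only the logprobs portion of a chat completions response.
type chatLogprobs struct {
	Choices []struct {
		Logprobs struct {
			Content []struct {
				Token       string  `json:"token"`
				Logprob     float64 `json:"logprob"`
				TopLogprobs []struct {
					Token   string  `json:"token"`
					Logprob float64 `json:"logprob"`
				} `json:"top_logprobs"`
			} `json:"content"`
		} `json:"logprobs"`
	} `json:"choices"`
}

func main() {
	body, err := json.Marshal(map[string]any{
		"model":        "llama3.2", // placeholder model name
		"messages":     []map[string]string{{"role": "user", "content": "Say hello in one word"}},
		"logprobs":     true,
		"top_logprobs": 3,
	})
	if err != nil {
		log.Fatal(err)
	}

	// Assumes a local Ollama server on the default port.
	resp, err := http.Post("http://localhost:11434/v1/chat/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	var out chatLogprobs
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		log.Fatal(err)
	}
	if len(out.Choices) == 0 {
		log.Fatal("no choices in response")
	}
	for _, tok := range out.Choices[0].Logprobs.Content {
		fmt.Printf("%q\tlogprob=%.4f\ttop=%d\n", tok.Token, tok.Logprob, len(tok.TopLogprobs))
	}
}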