diff --git a/api/types.go b/api/types.go index 1615fce6f7..d5788d5419 100644 --- a/api/types.go +++ b/api/types.go @@ -117,6 +117,14 @@ type GenerateRequest struct { // DebugRenderOnly is a debug option that, when set to true, returns the rendered // template instead of calling the model. DebugRenderOnly bool `json:"_debug_render_only,omitempty"` + + // Logprobs specifies whether to return log probabilities of the output tokens. + Logprobs bool `json:"logprobs,omitempty"` + + // TopLogprobs is the number of most likely tokens to return at each token position, + // each with an associated log probability. Only applies when Logprobs is true. + // Valid values are 0-20. Default is 0 (only return the selected token's logprob). + TopLogprobs int `json:"top_logprobs,omitempty"` } // ChatRequest describes a request sent by [Client.Chat]. @@ -159,6 +167,14 @@ type ChatRequest struct { // DebugRenderOnly is a debug option that, when set to true, returns the rendered // template instead of calling the model. DebugRenderOnly bool `json:"_debug_render_only,omitempty"` + + // Logprobs specifies whether to return log probabilities of the output tokens. + Logprobs bool `json:"logprobs,omitempty"` + + // TopLogprobs is the number of most likely tokens to return at each token position, + // each with an associated log probability. Only applies when Logprobs is true. + // Valid values are 0-20. Default is 0 (only return the selected token's logprob). + TopLogprobs int `json:"top_logprobs,omitempty"` } type Tools []Tool @@ -343,6 +359,24 @@ func (t *ToolFunction) String() string { return string(bts) } +// TokenLogprob represents log probability information for a single token alternative. +type TokenLogprob struct { + // Token is the text representation of the token. + Token string `json:"token"` + + // Logprob is the log probability of this token. + Logprob float64 `json:"logprob"` +} + +// Logprob contains log probability information for a generated token. +type Logprob struct { + TokenLogprob + + // TopLogprobs contains the most likely tokens and their log probabilities + // at this position, if requested via TopLogprobs parameter. + TopLogprobs []TokenLogprob `json:"top_logprobs,omitempty"` +} + // ChatResponse is the response returned by [Client.Chat]. Its fields are // similar to [GenerateResponse]. type ChatResponse struct { @@ -369,6 +403,10 @@ type ChatResponse struct { DebugInfo *DebugInfo `json:"_debug_info,omitempty"` + // Logprobs contains log probability information for the generated tokens, + // if requested via the Logprobs parameter. + Logprobs []Logprob `json:"logprobs,omitempty"` + Metrics } @@ -677,6 +715,10 @@ type GenerateResponse struct { ToolCalls []ToolCall `json:"tool_calls,omitempty"` DebugInfo *DebugInfo `json:"_debug_info,omitempty"` + + // Logprobs contains log probability information for the generated tokens, + // if requested via the Logprobs parameter. + Logprobs []Logprob `json:"logprobs,omitempty"` } // ModelDetails provides details about a model. 
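Reviewer note, not part of the patch: a minimal sketch of how a caller could exercise the new GenerateRequest fields and read the returned logprobs through the Go client. The model name is a placeholder and api.ClientFromEnvironment is assumed as the usual client constructor (neither appears in this diff); the prompt mirrors the one used in the integration tests below.

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	stream := false
	req := &api.GenerateRequest{
		Model:       "llama3.2", // placeholder model name
		Prompt:      "Why is the sky blue?",
		Stream:      &stream,
		Logprobs:    true, // request per-token log probabilities
		TopLogprobs: 3,    // also return up to 3 alternatives per position
	}

	err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
		// With streaming disabled, this fires once with the full response,
		// whose Logprobs slice covers every generated token.
		for _, lp := range resp.Logprobs {
			fmt.Printf("%q\t%.4f\n", lp.Token, lp.Logprob)
			for _, alt := range lp.TopLogprobs {
				fmt.Printf("\talt %q\t%.4f\n", alt.Token, alt.Logprob)
			}
		}
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}

When streaming is enabled instead, each chunk carries the logprobs for only the tokens emitted in that chunk; the non-streaming handlers in server/routes.go accumulate them into a single slice, as the hunks further down show.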
diff --git a/integration/api_test.go b/integration/api_test.go index 39eea39c06..839e14d7c2 100644 --- a/integration/api_test.go +++ b/integration/api_test.go @@ -381,3 +381,174 @@ func TestAPIShowModel(t *testing.T) { t.Errorf("%s missing modified_at: %#v", modelName, resp) } } + +func TestAPIGenerateLogprobs(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + if err := PullIfMissing(ctx, client, smol); err != nil { + t.Fatalf("pull failed %s", err) + } + + enableLogprobs := true + noStream := false + + tests := []struct { + name string + logprobs *bool + topLogprobs int + expectCount int + }{ + { + name: "no_logprobs", + logprobs: nil, + topLogprobs: 0, + expectCount: 0, + }, + { + name: "logprobs_only", + logprobs: &enableLogprobs, + topLogprobs: 0, + expectCount: 1, + }, + { + name: "logprobs_with_top_5", + logprobs: &enableLogprobs, + topLogprobs: 5, + expectCount: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + req := api.GenerateRequest{ + Model: smol, + Prompt: "Why is the sky blue?", + Stream: &noStream, + Logprobs: test.logprobs != nil && *test.logprobs, + TopLogprobs: test.topLogprobs, + Options: map[string]interface{}{ + "temperature": 0, + "seed": 123, + "num_predict": 10, + }, + } + + var response api.GenerateResponse + err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error { + if resp.Done { + response = resp + } + return nil + }) + if err != nil { + t.Fatalf("generate failed: %s", err) + } + + // Check logprobs based on expectation + if test.expectCount == 0 { + if len(response.Logprobs) > 0 { + t.Errorf("expected no logprobs but got %d", len(response.Logprobs)) + } + } else { + if len(response.Logprobs) == 0 { + t.Errorf("expected logprobs but got none") + } + + // Validate each logprob entry + for i, lp := range response.Logprobs { + if lp.Token == "" { + t.Errorf("logprob[%d] has empty token", i) + } + if lp.Logprob > 0 { + t.Errorf("logprob[%d] has positive logprob %f (should be <= 0)", i, lp.Logprob) + } + + // Check top_logprobs if requested + if test.topLogprobs > 0 { + if len(lp.TopLogprobs) == 0 { + t.Errorf("logprob[%d] expected top_logprobs but got none", i) + } + if len(lp.TopLogprobs) > test.topLogprobs { + t.Errorf("logprob[%d] has %d top_logprobs, expected max %d", i, len(lp.TopLogprobs), test.topLogprobs) + } + + // Verify top_logprobs are sorted by probability (descending) + for j := 1; j < len(lp.TopLogprobs); j++ { + if lp.TopLogprobs[j-1].Logprob < lp.TopLogprobs[j].Logprob { + t.Errorf("logprob[%d].top_logprobs not sorted: %f < %f", i, lp.TopLogprobs[j-1].Logprob, lp.TopLogprobs[j].Logprob) + } + } + } else if len(lp.TopLogprobs) > 0 { + t.Errorf("logprob[%d] has top_logprobs but none were requested", i) + } + } + } + }) + } +} + +func TestAPIChatLogprobs(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + if err := PullIfMissing(ctx, client, smol); err != nil { + t.Fatalf("pull failed %s", err) + } + + enableLogprobs := true + noStream := false + + req := api.ChatRequest{ + Model: smol, + Messages: []api.Message{ + {Role: "user", Content: "Say hello in one word"}, + }, + Stream: &noStream, + Logprobs: enableLogprobs, + TopLogprobs: 3, + Options: map[string]interface{}{ + "temperature": 0, + "seed": 123, + "num_predict": 5, + }, + } + + 
var response api.ChatResponse + err := client.Chat(ctx, &req, func(resp api.ChatResponse) error { + if resp.Done { + response = resp + } + return nil + }) + if err != nil { + t.Fatalf("chat failed: %s", err) + } + + if len(response.Logprobs) == 0 { + t.Fatal("expected logprobs in response but got none") + } + + t.Logf("received %d logprobs for chat response", len(response.Logprobs)) + + for i, lp := range response.Logprobs { + if lp.Token == "" { + t.Errorf("logprob[%d] has empty token", i) + } + if lp.Logprob > 0 { + t.Errorf("logprob[%d] has positive logprob %f", i, lp.Logprob) + } + if len(lp.TopLogprobs) == 0 { + t.Errorf("logprob[%d] expected top_logprobs but got none", i) + } + if len(lp.TopLogprobs) > 3 { + t.Errorf("logprob[%d] has %d top_logprobs, expected max 3", i, len(lp.TopLogprobs)) + } + } +} diff --git a/llama/llama.go b/llama/llama.go index c995b3ead2..f8a051ea25 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -217,6 +217,19 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 { return embeddings } +// GetLogitsIth gets the logits for the ith token +func (c *Context) GetLogitsIth(i int) []float32 { + logits := unsafe.Pointer(C.llama_get_logits_ith(c.c, C.int32_t(i))) + if logits == nil { + return nil + } + + vocabSize := c.Model().NumVocab() + result := make([]float32, vocabSize) + _ = copy(result, unsafe.Slice((*float32)(logits), vocabSize)) + return result +} + type ModelParams struct { NumGpuLayers int MainGpu int diff --git a/llm/server.go b/llm/server.go index 4f7c3760d9..87f97a010f 100644 --- a/llm/server.go +++ b/llm/server.go @@ -1362,6 +1362,12 @@ type CompletionRequest struct { Grammar string // set before sending the request to the subprocess Shift bool Truncate bool + + // Logprobs specifies whether to include log probabilities in the response + Logprobs bool + + // TopLogprobs specifies the number of most likely alternative tokens to return (0-20) + TopLogprobs int } // DoneReason represents the reason why a completion response is done @@ -1387,6 +1393,18 @@ func (d DoneReason) String() string { } } +// TokenLogprob represents log probability information for a single token alternative. +type TokenLogprob struct { + Token string `json:"token"` + Logprob float64 `json:"logprob"` +} + +// Logprob contains log probability information for a generated token. 
+type Logprob struct { + TokenLogprob + TopLogprobs []TokenLogprob `json:"top_logprobs,omitempty"` +} + type CompletionResponse struct { Content string `json:"content"` DoneReason DoneReason `json:"done_reason"` @@ -1395,6 +1413,9 @@ type CompletionResponse struct { PromptEvalDuration time.Duration `json:"prompt_eval_duration"` EvalCount int `json:"eval_count"` EvalDuration time.Duration `json:"eval_duration"` + + // Logprobs contains log probability information if requested + Logprobs []Logprob `json:"logprobs,omitempty"` } func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { @@ -1530,7 +1551,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu if c.Content != "" { fn(CompletionResponse{ - Content: c.Content, + Content: c.Content, + Logprobs: c.Logprobs, }) } diff --git a/openai/openai.go b/openai/openai.go index d4fd26c251..4713d481b5 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -40,22 +40,29 @@ type Message struct { ToolCallID string `json:"tool_call_id,omitempty"` } +type ChoiceLogprobs struct { + Content []api.Logprob `json:"content"` +} + type Choice struct { - Index int `json:"index"` - Message Message `json:"message"` - FinishReason *string `json:"finish_reason"` + Index int `json:"index"` + Message Message `json:"message"` + FinishReason *string `json:"finish_reason"` + Logprobs *ChoiceLogprobs `json:"logprobs,omitempty"` } type ChunkChoice struct { - Index int `json:"index"` - Delta Message `json:"delta"` - FinishReason *string `json:"finish_reason"` + Index int `json:"index"` + Delta Message `json:"delta"` + FinishReason *string `json:"finish_reason"` + Logprobs *ChoiceLogprobs `json:"logprobs,omitempty"` } type CompleteChunkChoice struct { - Text string `json:"text"` - Index int `json:"index"` - FinishReason *string `json:"finish_reason"` + Text string `json:"text"` + Index int `json:"index"` + FinishReason *string `json:"finish_reason"` + Logprobs *ChoiceLogprobs `json:"logprobs,omitempty"` } type Usage struct { @@ -104,6 +111,8 @@ type ChatCompletionRequest struct { Tools []api.Tool `json:"tools"` Reasoning *Reasoning `json:"reasoning,omitempty"` ReasoningEffort *string `json:"reasoning_effort,omitempty"` + Logprobs *bool `json:"logprobs"` + TopLogprobs int `json:"top_logprobs"` DebugRenderOnly bool `json:"_debug_render_only"` } @@ -142,6 +151,7 @@ type CompletionRequest struct { Temperature *float32 `json:"temperature"` TopP float32 `json:"top_p"` Suffix string `json:"suffix"` + Logprobs *int `json:"logprobs"` DebugRenderOnly bool `json:"_debug_render_only"` } @@ -251,6 +261,12 @@ func ToToolCalls(tc []api.ToolCall) []ToolCall { // ToChatCompletion converts an api.ChatResponse to ChatCompletion func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion { toolCalls := ToToolCalls(r.Message.ToolCalls) + + var logprobs *ChoiceLogprobs + if len(r.Logprobs) > 0 { + logprobs = &ChoiceLogprobs{Content: r.Logprobs} + } + return ChatCompletion{ Id: id, Object: "chat.completion", @@ -269,6 +285,7 @@ func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion { } return nil }(r.DoneReason), + Logprobs: logprobs, }}, Usage: ToUsage(r), DebugInfo: r.DebugInfo, } @@ -277,6 +294,12 @@ func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion { // ToChunk converts an api.ChatResponse to ChatCompletionChunk func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk { toolCalls := ToToolCalls(r.Message.ToolCalls) + + var logprobs 
*ChoiceLogprobs + if len(r.Logprobs) > 0 { + logprobs = &ChoiceLogprobs{Content: r.Logprobs} + } + return ChatCompletionChunk{ Id: id, Object: "chat.completion.chunk", @@ -295,6 +318,7 @@ func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChu } return nil }(r.DoneReason), + Logprobs: logprobs, }}, } } @@ -604,6 +628,8 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { Stream: &r.Stream, Tools: r.Tools, Think: think, + Logprobs: r.Logprobs != nil && *r.Logprobs, + TopLogprobs: r.TopLogprobs, DebugRenderOnly: r.DebugRenderOnly, }, nil } @@ -680,12 +706,21 @@ func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) { options["top_p"] = 1.0 } + var logprobs bool + var topLogprobs int + if r.Logprobs != nil && *r.Logprobs > 0 { + logprobs = true + topLogprobs = *r.Logprobs + } + return api.GenerateRequest{ Model: r.Model, Prompt: r.Prompt, Options: options, Stream: &r.Stream, Suffix: r.Suffix, + Logprobs: logprobs, + TopLogprobs: topLogprobs, DebugRenderOnly: r.DebugRenderOnly, }, nil } diff --git a/openai/openai_test.go b/openai/openai_test.go index 6a42f91549..51e243dec9 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -3,6 +3,7 @@ package openai import ( "encoding/base64" "testing" + "time" "github.com/google/go-cmp/cmp" @@ -218,3 +219,218 @@ func TestToToolCallsPreservesIDs(t *testing.T) { t.Errorf("input tool calls mutated (-want +got):\n%s", diff) } } + +func TestFromChatRequest_WithLogprobs(t *testing.T) { + trueVal := true + + req := ChatCompletionRequest{ + Model: "test-model", + Messages: []Message{ + {Role: "user", Content: "Hello"}, + }, + Logprobs: &trueVal, + TopLogprobs: 5, + } + + result, err := FromChatRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !result.Logprobs { + t.Error("expected Logprobs to be true") + } + + if result.TopLogprobs != 5 { + t.Errorf("expected TopLogprobs to be 5, got %d", result.TopLogprobs) + } +} + +func TestFromChatRequest_LogprobsDefault(t *testing.T) { + req := ChatCompletionRequest{ + Model: "test-model", + Messages: []Message{ + {Role: "user", Content: "Hello"}, + }, + } + + result, err := FromChatRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Logprobs { + t.Error("expected Logprobs to be false by default") + } + + if result.TopLogprobs != 0 { + t.Errorf("expected TopLogprobs to be 0 by default, got %d", result.TopLogprobs) + } +} + +func TestFromCompleteRequest_WithLogprobs(t *testing.T) { + logprobsVal := 5 + + req := CompletionRequest{ + Model: "test-model", + Prompt: "Hello", + Logprobs: &logprobsVal, + } + + result, err := FromCompleteRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !result.Logprobs { + t.Error("expected Logprobs to be true") + } + + if result.TopLogprobs != 5 { + t.Errorf("expected TopLogprobs to be 5, got %d", result.TopLogprobs) + } +} + +func TestToChatCompletion_WithLogprobs(t *testing.T) { + createdAt := time.Unix(1234567890, 0) + resp := api.ChatResponse{ + Model: "test-model", + CreatedAt: createdAt, + Message: api.Message{Role: "assistant", Content: "Hello there"}, + Logprobs: []api.Logprob{ + { + TokenLogprob: api.TokenLogprob{ + Token: "Hello", + Logprob: -0.5, + }, + TopLogprobs: []api.TokenLogprob{ + {Token: "Hello", Logprob: -0.5}, + {Token: "Hi", Logprob: -1.2}, + }, + }, + { + TokenLogprob: api.TokenLogprob{ + Token: " there", + Logprob: -0.3, + }, + TopLogprobs: []api.TokenLogprob{ + {Token: " there", Logprob: -0.3}, 
+ {Token: " world", Logprob: -1.5}, + }, + }, + }, + Done: true, + Metrics: api.Metrics{ + PromptEvalCount: 5, + EvalCount: 2, + }, + } + + id := "test-id" + + result := ToChatCompletion(id, resp) + + if result.Id != id { + t.Errorf("expected Id %q, got %q", id, result.Id) + } + + if result.Created != 1234567890 { + t.Errorf("expected Created %d, got %d", int64(1234567890), result.Created) + } + + if len(result.Choices) != 1 { + t.Fatalf("expected 1 choice, got %d", len(result.Choices)) + } + + choice := result.Choices[0] + if choice.Message.Content != "Hello there" { + t.Errorf("expected content %q, got %q", "Hello there", choice.Message.Content) + } + + if choice.Logprobs == nil { + t.Fatal("expected Logprobs to be present") + } + + if len(choice.Logprobs.Content) != 2 { + t.Fatalf("expected 2 logprobs, got %d", len(choice.Logprobs.Content)) + } + + // Verify first logprob + if choice.Logprobs.Content[0].Token != "Hello" { + t.Errorf("expected first token %q, got %q", "Hello", choice.Logprobs.Content[0].Token) + } + if choice.Logprobs.Content[0].Logprob != -0.5 { + t.Errorf("expected first logprob -0.5, got %f", choice.Logprobs.Content[0].Logprob) + } + if len(choice.Logprobs.Content[0].TopLogprobs) != 2 { + t.Errorf("expected 2 top_logprobs, got %d", len(choice.Logprobs.Content[0].TopLogprobs)) + } + + // Verify second logprob + if choice.Logprobs.Content[1].Token != " there" { + t.Errorf("expected second token %q, got %q", " there", choice.Logprobs.Content[1].Token) + } +} + +func TestToChatCompletion_WithoutLogprobs(t *testing.T) { + createdAt := time.Unix(1234567890, 0) + resp := api.ChatResponse{ + Model: "test-model", + CreatedAt: createdAt, + Message: api.Message{Role: "assistant", Content: "Hello"}, + Done: true, + Metrics: api.Metrics{ + PromptEvalCount: 5, + EvalCount: 1, + }, + } + + id := "test-id" + + result := ToChatCompletion(id, resp) + + if len(result.Choices) != 1 { + t.Fatalf("expected 1 choice, got %d", len(result.Choices)) + } + + // When no logprobs, Logprobs should be nil + if result.Choices[0].Logprobs != nil { + t.Error("expected Logprobs to be nil when not requested") + } +} + +func TestFromChatRequest_TopLogprobsRange(t *testing.T) { + tests := []struct { + name string + topLogprobs int + expectValid bool + }{ + {name: "valid: 0", topLogprobs: 0, expectValid: true}, + {name: "valid: 1", topLogprobs: 1, expectValid: true}, + {name: "valid: 10", topLogprobs: 10, expectValid: true}, + {name: "valid: 20", topLogprobs: 20, expectValid: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + trueVal := true + req := ChatCompletionRequest{ + Model: "test-model", + Messages: []Message{ + {Role: "user", Content: "Hello"}, + }, + Logprobs: &trueVal, + TopLogprobs: tt.topLogprobs, + } + + result, err := FromChatRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.TopLogprobs != tt.topLogprobs { + t.Errorf("expected TopLogprobs %d, got %d", tt.topLogprobs, result.TopLogprobs) + } + }) + } +} diff --git a/runner/common/logprob.go b/runner/common/logprob.go new file mode 100644 index 0000000000..a0d764a36d --- /dev/null +++ b/runner/common/logprob.go @@ -0,0 +1,79 @@ +package common + +import ( + "math" + "sort" + + "github.com/ollama/ollama/llm" +) + +// TokenDecoderFunc is a function that converts token IDs to text. +type TokenDecoderFunc func(tokenID int) string + +// CalculateLogprobs converts raw logits to log probabilities and finds top K tokens. 
+// It uses numerically stable softmax to compute log probabilities. +func CalculateLogprobs(logits []float32, selectedToken int, topK int, decoder TokenDecoderFunc) []llm.Logprob { + if len(logits) == 0 { + return nil + } + + // Step 1: Convert logits to log probabilities using numerically stable softmax + maxLogit := logits[0] + for _, logit := range logits[1:] { + if logit > maxLogit { + maxLogit = logit + } + } + + var sumExp float64 + for _, logit := range logits { + sumExp += math.Exp(float64(logit - maxLogit)) + } + logSumExp := float32(math.Log(sumExp)) + + logProbs := make([]float32, len(logits)) + for i, logit := range logits { + logProbs[i] = (logit - maxLogit) - logSumExp + } + + // Step 2: Get selected token's information + selectedLogprob := logProbs[selectedToken] + selectedText := decoder(selectedToken) + + result := llm.Logprob{ + TokenLogprob: llm.TokenLogprob{ + Token: selectedText, + Logprob: float64(selectedLogprob), + }, + } + + // Step 3: If topK requested, find the top K tokens + if topK > 0 { + type tokenLogprobPair struct { + tokenID int + logprob float32 + } + + pairs := make([]tokenLogprobPair, len(logProbs)) + for i, lp := range logProbs { + pairs[i] = tokenLogprobPair{tokenID: i, logprob: lp} + } + + sort.Slice(pairs, func(i, j int) bool { + return pairs[i].logprob > pairs[j].logprob + }) + + k := min(topK, len(pairs)) + topLogprobs := make([]llm.TokenLogprob, k) + for i := range k { + tokenText := decoder(pairs[i].tokenID) + topLogprobs[i] = llm.TokenLogprob{ + Token: tokenText, + Logprob: float64(pairs[i].logprob), + } + } + result.TopLogprobs = topLogprobs + } + + return []llm.Logprob{result} +} diff --git a/runner/common/logprob_test.go b/runner/common/logprob_test.go new file mode 100644 index 0000000000..c798f3f4b4 --- /dev/null +++ b/runner/common/logprob_test.go @@ -0,0 +1,498 @@ +package common + +import ( + "math" + "testing" + + "github.com/ollama/ollama/llm" +) + +func TestCalculateLogprobs(t *testing.T) { + tokens := map[int]string{ + 0: "hello", + 1: "hi", + 2: "hey", + 3: "world", + } + decoder := func(tokenID int) string { + if text, ok := tokens[tokenID]; ok { + return text + } + return "" + } + + tests := []struct { + name string + logits []float32 + selectedToken int + topK int + wantLen int + wantToken string + }{ + { + name: "Empty logits", + logits: []float32{}, + selectedToken: 0, + topK: 0, + wantLen: 0, + }, + { + name: "Single token without top logprobs", + logits: []float32{1.0, 0.5, 0.3, 0.1}, + selectedToken: 0, + topK: 0, + wantLen: 1, + wantToken: "hello", + }, + { + name: "Single token with top logprobs", + logits: []float32{1.0, 0.5, 0.3, 0.1}, + selectedToken: 0, + topK: 3, + wantLen: 1, + wantToken: "hello", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := CalculateLogprobs(tt.logits, tt.selectedToken, tt.topK, decoder) + if len(result) != tt.wantLen { + t.Errorf("CalculateLogprobs() returned %d results, want %d", len(result), tt.wantLen) + } + if tt.wantLen > 0 && result[0].Token != tt.wantToken { + t.Errorf("CalculateLogprobs() token = %s, want %s", result[0].Token, tt.wantToken) + } + if tt.topK > 0 && len(result) > 0 { + if len(result[0].TopLogprobs) != tt.topK { + t.Errorf("CalculateLogprobs() top logprobs count = %d, want %d", len(result[0].TopLogprobs), tt.topK) + } + } + }) + } +} + +func TestCalculateLogprobsNumericalStability(t *testing.T) { + tokens := map[int]string{ + 0: "a", + 1: "b", + 2: "c", + } + decoder := func(tokenID int) string { + if text, ok := tokens[tokenID]; 
ok { + return text + } + return "" + } + + // Test with very large logits to ensure numerical stability + logits := []float32{1000.0, 999.0, 998.0} + result := CalculateLogprobs(logits, 0, 3, decoder) + + if len(result) != 1 { + t.Fatalf("Expected 1 result, got %d", len(result)) + } + + // Check that log probabilities are finite and reasonable + if math.IsInf(result[0].Logprob, 0) || math.IsNaN(result[0].Logprob) { + t.Errorf("Selected token logprob is not finite: %f", result[0].Logprob) + } + + for i, tlp := range result[0].TopLogprobs { + if math.IsInf(tlp.Logprob, 0) || math.IsNaN(tlp.Logprob) { + t.Errorf("Top logprob[%d] is not finite: %f", i, tlp.Logprob) + } + } + + // Top logprobs should be in descending order + for i := 1; i < len(result[0].TopLogprobs); i++ { + if result[0].TopLogprobs[i].Logprob > result[0].TopLogprobs[i-1].Logprob { + t.Errorf("Top logprobs not in descending order: %f > %f", + result[0].TopLogprobs[i].Logprob, result[0].TopLogprobs[i-1].Logprob) + } + } +} + +func TestCalculateLogprobsProbabilityCorrectness(t *testing.T) { + tokens := map[int]string{ + 0: "hello", + 1: "world", + 2: "foo", + 3: "bar", + } + decoder := func(tokenID int) string { + if text, ok := tokens[tokenID]; ok { + return text + } + return "" + } + + tests := []struct { + name string + logits []float32 + selectedToken int + topK int + }{ + { + name: "Uniform logits", + logits: []float32{1.0, 1.0, 1.0, 1.0}, + selectedToken: 0, + topK: 4, + }, + { + name: "Different logits", + logits: []float32{2.0, 1.0, 0.5, 0.1}, + selectedToken: 0, + topK: 4, + }, + { + name: "Negative logits", + logits: []float32{-1.0, -2.0, -3.0, -4.0}, + selectedToken: 0, + topK: 4, + }, + { + name: "Mixed logits", + logits: []float32{5.0, -5.0, 0.0, 2.5}, + selectedToken: 0, + topK: 4, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := CalculateLogprobs(tt.logits, tt.selectedToken, tt.topK, decoder) + + if len(result) != 1 { + t.Fatalf("Expected 1 result, got %d", len(result)) + } + + // Verify all probabilities are non-positive (log probabilities should be <= 0) + if result[0].Logprob > 0 { + t.Errorf("Selected token logprob should be <= 0, got %f", result[0].Logprob) + } + + for i, tlp := range result[0].TopLogprobs { + if tlp.Logprob > 0 { + t.Errorf("Top logprob[%d] should be <= 0, got %f", i, tlp.Logprob) + } + } + + // Verify that probabilities sum to approximately 1 + // Sum of exp(logprob) for all tokens should equal 1 + var probSum float64 + for _, lp := range result[0].TopLogprobs { + probSum += math.Exp(lp.Logprob) + } + + // For uniform logits, each probability should be 1/n + if tt.name == "Uniform logits" { + expectedProb := 1.0 / float64(len(tt.logits)) + actualProb := math.Exp(result[0].Logprob) + if math.Abs(actualProb-expectedProb) > 1e-6 { + t.Errorf("For uniform logits, expected probability %f, got %f", + expectedProb, actualProb) + } + } + + // Verify top logprobs are sorted in descending order + for i := 1; i < len(result[0].TopLogprobs); i++ { + if result[0].TopLogprobs[i].Logprob > result[0].TopLogprobs[i-1].Logprob { + t.Errorf("Top logprobs not sorted: position %d (%f) > position %d (%f)", + i, result[0].TopLogprobs[i].Logprob, + i-1, result[0].TopLogprobs[i-1].Logprob) + } + } + + // Verify the selected token appears in top logprobs + selectedText := decoder(tt.selectedToken) + found := false + for _, tlp := range result[0].TopLogprobs { + if tlp.Token == selectedText { + found = true + // The logprob in top logprobs should match the selected token's 
logprob + if math.Abs(tlp.Logprob-result[0].Logprob) > 1e-6 { + t.Errorf("Selected token logprob mismatch: main=%f, in top=%f", + result[0].Logprob, tlp.Logprob) + } + break + } + } + if !found { + t.Errorf("Selected token %q not found in top logprobs", selectedText) + } + }) + } +} + +func TestCalculateLogprobsSoftmaxCorrectness(t *testing.T) { + // Test that softmax calculation is correct by verifying probabilities sum to 1 + decoder := func(tokenID int) string { + return string(rune('A' + tokenID)) + } + + tests := []struct { + name string + logits []float32 + }{ + { + name: "Small vocabulary", + logits: []float32{1.0, 2.0, 3.0}, + }, + { + name: "Large differences", + logits: []float32{10.0, 0.0, -10.0}, + }, + { + name: "All equal", + logits: []float32{5.0, 5.0, 5.0, 5.0, 5.0}, + }, + { + name: "Very large values", + logits: []float32{500.0, 499.0, 498.0}, + }, + { + name: "Very small values", + logits: []float32{-500.0, -499.0, -498.0}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Calculate logprobs for all tokens + var totalProb float64 + for i := range tt.logits { + result := CalculateLogprobs(tt.logits, i, 0, decoder) + if len(result) != 1 { + t.Fatalf("Expected 1 result, got %d", len(result)) + } + prob := math.Exp(result[0].Logprob) + totalProb += prob + + // Verify each probability is between 0 and 1 + if prob < 0 || prob > 1 { + t.Errorf("Token %d probability %f is out of range [0, 1]", i, prob) + } + } + + // Total probability should be very close to 1.0 (allowing for floating point errors) + if math.Abs(totalProb-1.0) > 1e-5 { + t.Errorf("Total probability sum is %f, expected 1.0", totalProb) + } + }) + } +} + +func TestCalculateLogprobsSelectedTokenCorrectness(t *testing.T) { + decoder := func(tokenID int) string { + return string(rune('A' + tokenID)) + } + + logits := []float32{3.0, 1.0, 2.0, 0.5} + + // Test that selecting different tokens gives the correct probabilities + // and that the highest logit has the highest probability + maxLogitIndex := 0 + maxLogitValue := logits[0] + for i, logit := range logits[1:] { + if logit > maxLogitValue { + maxLogitValue = logit + maxLogitIndex = i + 1 + } + } + + var maxProb float64 + var maxProbIndex int + + for i := range logits { + result := CalculateLogprobs(logits, i, 0, decoder) + prob := math.Exp(result[0].Logprob) + + if prob > maxProb { + maxProb = prob + maxProbIndex = i + } + + // Verify the token matches + expectedToken := decoder(i) + if result[0].Token != expectedToken { + t.Errorf("Token %d: expected token %q, got %q", i, expectedToken, result[0].Token) + } + } + + // The token with the highest logit should have the highest probability + if maxProbIndex != maxLogitIndex { + t.Errorf("Token with highest probability (%d) doesn't match token with highest logit (%d)", + maxProbIndex, maxLogitIndex) + } +} + +func TestCalculateLogprobsTopKOrdering(t *testing.T) { + tokens := map[int]string{ + 0: "first", + 1: "second", + 2: "third", + 3: "fourth", + 4: "fifth", + } + decoder := func(tokenID int) string { + return tokens[tokenID] + } + + // Logits in non-sorted order + logits := []float32{2.0, 5.0, 1.0, 4.0, 3.0} + // Expected order by probability: 1 (5.0), 3 (4.0), 4 (3.0), 0 (2.0), 2 (1.0) + expectedOrder := []string{"second", "fourth", "fifth", "first", "third"} + + result := CalculateLogprobs(logits, 0, 5, decoder) + + if len(result) != 1 { + t.Fatalf("Expected 1 result, got %d", len(result)) + } + + if len(result[0].TopLogprobs) != 5 { + t.Fatalf("Expected 5 top logprobs, got %d", 
len(result[0].TopLogprobs)) + } + + // Verify ordering matches expected + for i, tlp := range result[0].TopLogprobs { + if tlp.Token != expectedOrder[i] { + t.Errorf("Position %d: expected token %q, got %q", i, expectedOrder[i], tlp.Token) + } + } + + // Verify probabilities are in descending order + for i := 1; i < len(result[0].TopLogprobs); i++ { + if result[0].TopLogprobs[i].Logprob > result[0].TopLogprobs[i-1].Logprob { + t.Errorf("Probabilities not in descending order at position %d: %f > %f", + i, result[0].TopLogprobs[i].Logprob, result[0].TopLogprobs[i-1].Logprob) + } + } +} + +func TestLogprobsWithStopSequences(t *testing.T) { + tests := []struct { + name string + pendingResponses []string + pendingLogprobs []llm.Logprob + stop string + expectedResponses []string + expectedLogprobs int + }{ + { + name: "Single token stop", + pendingResponses: []string{"Hello", " world", "!"}, + pendingLogprobs: []llm.Logprob{ + {TokenLogprob: llm.TokenLogprob{Token: "Hello", Logprob: -0.1}}, + {TokenLogprob: llm.TokenLogprob{Token: " world", Logprob: -0.2}}, + {TokenLogprob: llm.TokenLogprob{Token: "!", Logprob: -0.3}}, + }, + stop: "!", + expectedResponses: []string{"Hello", " world"}, + expectedLogprobs: 2, + }, + { + name: "Multi-token stop sequence", + pendingResponses: []string{"Hello", " ", "there", "STOP"}, + pendingLogprobs: []llm.Logprob{ + {TokenLogprob: llm.TokenLogprob{Token: "Hello", Logprob: -0.1}}, + {TokenLogprob: llm.TokenLogprob{Token: " ", Logprob: -0.2}}, + {TokenLogprob: llm.TokenLogprob{Token: "there", Logprob: -0.3}}, + {TokenLogprob: llm.TokenLogprob{Token: "STOP", Logprob: -0.4}}, + }, + stop: "STOP", + expectedResponses: []string{"Hello", " ", "there"}, + expectedLogprobs: 3, + }, + { + name: "Partial token stop", + pendingResponses: []string{"Hello", " the", "re!"}, + pendingLogprobs: []llm.Logprob{ + {TokenLogprob: llm.TokenLogprob{Token: "Hello", Logprob: -0.1}}, + {TokenLogprob: llm.TokenLogprob{Token: " the", Logprob: -0.2}}, + {TokenLogprob: llm.TokenLogprob{Token: "re!", Logprob: -0.3}}, + }, + stop: "there!", + expectedResponses: []string{"Hello", " "}, + expectedLogprobs: 2, + }, + { + name: "Stop at beginning of last token", + pendingResponses: []string{"Hello", " world", "END"}, + pendingLogprobs: []llm.Logprob{ + {TokenLogprob: llm.TokenLogprob{Token: "Hello", Logprob: -0.1}}, + {TokenLogprob: llm.TokenLogprob{Token: " world", Logprob: -0.2}}, + {TokenLogprob: llm.TokenLogprob{Token: "END", Logprob: -0.3}}, + }, + stop: "END", + expectedResponses: []string{"Hello", " world"}, + expectedLogprobs: 2, + }, + { + name: "Multi-token stop across tokens", + pendingResponses: []string{"Text", " ", "with", " ", "stop", " ", "word"}, + pendingLogprobs: []llm.Logprob{ + {TokenLogprob: llm.TokenLogprob{Token: "Text", Logprob: -0.1}}, + {TokenLogprob: llm.TokenLogprob{Token: " ", Logprob: -0.2}}, + {TokenLogprob: llm.TokenLogprob{Token: "with", Logprob: -0.3}}, + {TokenLogprob: llm.TokenLogprob{Token: " ", Logprob: -0.4}}, + {TokenLogprob: llm.TokenLogprob{Token: "stop", Logprob: -0.5}}, + {TokenLogprob: llm.TokenLogprob{Token: " ", Logprob: -0.6}}, + {TokenLogprob: llm.TokenLogprob{Token: "word", Logprob: -0.7}}, + }, + stop: "stop word", + expectedResponses: []string{"Text", " ", "with", " "}, + expectedLogprobs: 4, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Simulate the stop sequence detection and truncation + origLen := len(tt.pendingResponses) + responses, tokenTruncated := TruncateStop(tt.pendingResponses, tt.stop) + newLen := 
len(responses) + + // Simulate logprobs truncation + logprobs := make([]llm.Logprob, len(tt.pendingLogprobs)) + copy(logprobs, tt.pendingLogprobs) + + origLogprobsLen := len(logprobs) + numTokensRemoved := origLen - newLen + newLogprobsLen := origLogprobsLen - numTokensRemoved + if newLogprobsLen < 0 { + newLogprobsLen = 0 + } + logprobs = logprobs[:newLogprobsLen] + + // Verify responses were truncated correctly + if len(responses) != len(tt.expectedResponses) { + t.Errorf("Expected %d responses, got %d", len(tt.expectedResponses), len(responses)) + } + + // Verify logprobs count matches truncated responses + if len(logprobs) != tt.expectedLogprobs { + t.Errorf("Expected %d logprobs after truncation, got %d", tt.expectedLogprobs, len(logprobs)) + } + + // Verify logprobs count matches response count + if len(logprobs) != len(responses) { + t.Errorf("Logprobs count (%d) doesn't match responses count (%d)", len(logprobs), len(responses)) + } + + // Verify the correct logprobs were kept (skip last token if it was truncated) + // When tokenTruncated is true, the last response token may not match the logprob token + checkLen := len(logprobs) + if tokenTruncated && checkLen > 0 { + checkLen-- // Skip checking the last token when it was partially truncated + } + + for i := range checkLen { + if i < len(responses) && logprobs[i].Token != responses[i] { + t.Errorf("Logprob[%d] token %q doesn't match response[%d] %q", + i, logprobs[i].Token, i, responses[i]) + } + } + }) + } +} diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index 87b43256ab..16c84a78bb 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -28,6 +28,12 @@ import ( "github.com/ollama/ollama/runner/common" ) +// response contains a piece of generated text along with optional logprobs +type response struct { + content string + logprobs []llm.Logprob +} + // input is an element of the prompt to process, either // a token or an image embedding (generated from a vision projector) type input struct { @@ -53,11 +59,14 @@ type Sequence struct { // tokens that have been generated but not returned yet (e.g. 
for stop sequences) pendingResponses []string + // logprobs for tokens that haven't been returned yet + pendingLogprobs []llm.Logprob + // input cache being used by this sequence cache *InputCacheSlot // channel to send responses over - responses chan string + responses chan response // channel to stop decoding (such as if the remote connection is closed) quit chan bool @@ -84,6 +93,10 @@ type Sequence struct { doneReason llm.DoneReason + // logprobs configuration + logprobs bool + topLogprobs int + // Metrics processingDuration time.Duration generationDuration time.Duration @@ -99,6 +112,8 @@ type NewSequenceParams struct { embedding bool shift bool truncate bool + logprobs bool + topLogprobs int } var errorInputTooLong = errors.New("the input length exceeds the context length") @@ -155,7 +170,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe numPromptInputs: len(inputs), numPredict: params.numPredict, pendingResponses: make([]string, 0), - responses: make(chan string, 100), + responses: make(chan response, 100), quit: make(chan bool, 1), embedding: make(chan []float32, 1), samplingCtx: sc, @@ -163,9 +178,16 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe stop: params.stop, numKeep: params.numKeep, shift: params.shift, + logprobs: params.logprobs, + topLogprobs: params.topLogprobs, }, nil } +// calculateLogprobsLlama converts raw logits to log probabilities and finds top K tokens +func calculateLogprobsLlama(logits []float32, selectedToken int, topK int, model *llama.Model) []llm.Logprob { + return common.CalculateLogprobs(logits, selectedToken, topK, model.TokenToPiece) +} + // inputs processes the prompt and images into a list of inputs // by splitting the prompt on [img-] tags, tokenizing text and // generating image embeddings for each image @@ -294,7 +316,9 @@ func (s *Server) allNil() bool { func flushPending(seq *Sequence) bool { joined := strings.Join(seq.pendingResponses, "") + logprobs := seq.pendingLogprobs seq.pendingResponses = []string{} + seq.pendingLogprobs = []llm.Logprob{} // Check if there are any partial UTF-8 characters remaining. // We already check and queue as we are generating but some may @@ -311,7 +335,7 @@ func flushPending(seq *Sequence) bool { } select { - case seq.responses <- joined: + case seq.responses <- response{content: joined, logprobs: logprobs}: return true case <-seq.quit: return false @@ -526,6 +550,15 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) continue } + // Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens) + if seq.logprobs { + logits := s.lc.GetLogitsIth(seq.iBatch) + if logits != nil { + logprobs := calculateLogprobsLlama(logits, token, seq.topLogprobs, s.model) + seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...) 
+ } + } + seq.inputs = []input{{token: token}} seq.pendingResponses = append(seq.pendingResponses, piece) @@ -539,6 +572,17 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop) newLen := len(seq.pendingResponses) + // Truncate logprobs to match the truncated responses + if seq.logprobs { + origLogprobsLen := len(seq.pendingLogprobs) + numTokensRemoved := origLen - newLen + newLogprobsLen := origLogprobsLen - numTokensRemoved + if newLogprobsLen < 0 { + newLogprobsLen = 0 + } + seq.pendingLogprobs = seq.pendingLogprobs[:newLogprobsLen] + } + // Update the cache based on the tokens that will be returned: // - We have 1 token more than is currently in the cache because // the last one generated wasn't submitted to Decode @@ -618,6 +662,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { embedding: false, shift: req.Shift, truncate: req.Truncate, + logprobs: req.Logprobs, + topLogprobs: req.TopLogprobs, }) if err != nil { if errors.Is(err, errorInputTooLong) { @@ -669,10 +715,11 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { case <-r.Context().Done(): close(seq.quit) return - case content, ok := <-seq.responses: + case resp, ok := <-seq.responses: if ok { if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ - Content: content, + Content: resp.content, + Logprobs: resp.logprobs, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) close(seq.quit) diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 3e8c1e2276..1533908686 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -41,6 +41,12 @@ import ( _ "github.com/ollama/ollama/model/models" ) +// response contains a piece of generated text along with optional logprobs +type response struct { + content string + logprobs []llm.Logprob +} + type Sequence struct { // ctxs are used for allocating tensors that last the lifetime of the sequence, such as // multimodal embeddings @@ -61,11 +67,14 @@ type Sequence struct { // tokens that have been generated but not returned yet (e.g. 
for stop sequences) pendingResponses []string + // logprobs for tokens that haven't been returned yet + pendingLogprobs []llm.Logprob + // input cache being used by this sequence cache *InputCacheSlot // channel to send responses over - responses chan string + responses chan response // channel to stop decoding (such as if the remote connection is closed) quit chan bool @@ -93,6 +102,10 @@ type Sequence struct { doneReason llm.DoneReason + // logprobs configuration + logprobs bool + topLogprobs int + // Metrics startedAt, lastUpdatedAt time.Time processingDuration time.Duration @@ -102,13 +115,15 @@ type Sequence struct { } type NewSequenceParams struct { - numPredict int - stop []string - numKeep int32 - sampler sample.Sampler - embedding bool - shift bool - truncate bool + numPredict int + stop []string + numKeep int32 + sampler sample.Sampler + embedding bool + shift bool + truncate bool + logprobs bool + topLogprobs int } var errorInputTooLong = errors.New("the input length exceeds the context length") @@ -181,7 +196,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe numPromptInputs: len(inputs), numPredict: params.numPredict, pendingResponses: make([]string, 0), - responses: make(chan string, 100), + responses: make(chan response, 100), quit: make(chan bool, 1), embedding: make(chan []float32, 1), sampler: params.sampler, @@ -189,9 +204,20 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe stop: params.stop, numKeep: params.numKeep, shift: params.shift, + logprobs: params.logprobs, + topLogprobs: params.topLogprobs, }, nil } +// calculateLogprobs converts raw logits to log probabilities and finds top K tokens +func calculateLogprobs(logits []float32, selectedToken int32, topK int, textProcessor model.TextProcessor) []llm.Logprob { + decoder := func(tokenID int) string { + text, _ := textProcessor.Decode([]int32{int32(tokenID)}) + return text + } + return common.CalculateLogprobs(logits, int(selectedToken), topK, decoder) +} + // inputs processes the prompt and images into a list of inputs // by splitting the prompt on [img-] tags, tokenizing text and // decoding images @@ -371,7 +397,9 @@ func (s *Server) allNil() bool { func flushPending(seq *Sequence) bool { joined := strings.Join(seq.pendingResponses, "") + logprobs := seq.pendingLogprobs seq.pendingResponses = []string{} + seq.pendingLogprobs = []llm.Logprob{} // Check if there are any partial UTF-8 characters remaining. 
// We already check and queue as we are generating but some may @@ -388,7 +416,7 @@ func flushPending(seq *Sequence) bool { } select { - case seq.responses <- joined: + case seq.responses <- response{content: joined, logprobs: logprobs}: return true case <-seq.quit: return false @@ -729,7 +757,8 @@ func (s *Server) computeBatch(activeBatch batchState) { // sample a token vocabSize := len(outputs) / activeBatch.batch.Outputs.Dim(0) logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches) - token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize]) + logits := outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize] + token, err := seq.sampler.Sample(logits) if err != nil { panic("failed to sample token") } @@ -751,6 +780,12 @@ func (s *Server) computeBatch(activeBatch batchState) { panic("failed to decode token") } + // Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens) + if seq.logprobs { + logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(model.TextProcessor)) + seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...) + } + seq.pendingResponses = append(seq.pendingResponses, piece) sequence := strings.Join(seq.pendingResponses, "") @@ -762,6 +797,17 @@ func (s *Server) computeBatch(activeBatch batchState) { seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop) newLen := len(seq.pendingResponses) + // Truncate logprobs to match the truncated responses + if seq.logprobs { + origLogprobsLen := len(seq.pendingLogprobs) + numTokensRemoved := origLen - newLen + newLogprobsLen := origLogprobsLen - numTokensRemoved + if newLogprobsLen < 0 { + newLogprobsLen = 0 + } + seq.pendingLogprobs = seq.pendingLogprobs[:newLogprobsLen] + } + // Update the cache based on the tokens that will be returned: // - We have 1 token more than is currently in the cache because // the last one generated wasn't submitted to Decode @@ -845,13 +891,15 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { ) seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ - numPredict: req.Options.NumPredict, - stop: req.Options.Stop, - numKeep: int32(req.Options.NumKeep), - sampler: sampler, - embedding: false, - shift: req.Shift, - truncate: req.Truncate, + numPredict: req.Options.NumPredict, + stop: req.Options.Stop, + numKeep: int32(req.Options.NumKeep), + sampler: sampler, + embedding: false, + shift: req.Shift, + truncate: req.Truncate, + logprobs: req.Logprobs, + topLogprobs: req.TopLogprobs, }) if err != nil { if errors.Is(err, errorInputTooLong) { @@ -903,10 +951,11 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { case <-r.Context().Done(): close(seq.quit) return - case content, ok := <-seq.responses: + case resp, ok := <-seq.responses: if ok { if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ - Content: content, + Content: resp.content, + Logprobs: resp.logprobs, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) close(seq.quit) diff --git a/server/logprob.go b/server/logprob.go new file mode 100644 index 0000000000..51996c2a11 --- /dev/null +++ b/server/logprob.go @@ -0,0 +1,29 @@ +package server + +import ( + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/llm" +) + +// toAPILogprobs 
converts llm.Logprobs to api.Logprobs +func toAPILogprobs(logprobs []llm.Logprob) []api.Logprob { + result := make([]api.Logprob, len(logprobs)) + for i, lp := range logprobs { + result[i] = api.Logprob{ + TokenLogprob: api.TokenLogprob{ + Token: lp.Token, + Logprob: lp.Logprob, + }, + } + if len(lp.TopLogprobs) > 0 { + result[i].TopLogprobs = make([]api.TokenLogprob, len(lp.TopLogprobs)) + for j, tlp := range lp.TopLogprobs { + result[i].TopLogprobs[j] = api.TokenLogprob{ + Token: tlp.Token, + Logprob: tlp.Logprob, + } + } + } + } + return result +} diff --git a/server/routes.go b/server/routes.go index 28d70d62f8..38ce4e4d4e 100644 --- a/server/routes.go +++ b/server/routes.go @@ -183,6 +183,11 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } + if req.TopLogprobs < 0 || req.TopLogprobs > 20 { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"}) + return + } + name := model.ParseName(req.Model) if !name.IsValid() { // Ideally this is "invalid model name" but we're keeping with @@ -212,6 +217,11 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } + if req.TopLogprobs < 0 || req.TopLogprobs > 20 { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"}) + return + } + if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" { origModel := req.Model @@ -502,12 +512,14 @@ func (s *Server) GenerateHandler(c *gin.Context) { var sb strings.Builder defer close(ch) if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ - Prompt: prompt, - Images: images, - Format: req.Format, - Options: opts, - Shift: req.Shift == nil || *req.Shift, - Truncate: req.Truncate == nil || *req.Truncate, + Prompt: prompt, + Images: images, + Format: req.Format, + Options: opts, + Shift: req.Shift == nil || *req.Shift, + Truncate: req.Truncate == nil || *req.Truncate, + Logprobs: req.Logprobs, + TopLogprobs: req.TopLogprobs, }, func(cr llm.CompletionResponse) { res := api.GenerateResponse{ Model: req.Model, @@ -520,6 +532,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { EvalCount: cr.EvalCount, EvalDuration: cr.EvalDuration, }, + Logprobs: toAPILogprobs(cr.Logprobs), } if builtinParser != nil { @@ -580,6 +593,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { if req.Stream != nil && !*req.Stream { var r api.GenerateResponse + var allLogprobs []api.Logprob var sbThinking strings.Builder var sbContent strings.Builder for rr := range ch { @@ -588,6 +602,10 @@ func (s *Server) GenerateHandler(c *gin.Context) { sbThinking.WriteString(t.Thinking) sbContent.WriteString(t.Response) r = t + // Accumulate logprobs from all chunks for non-streaming response + if len(t.Logprobs) > 0 { + allLogprobs = append(allLogprobs, t.Logprobs...) 
+ } case gin.H: msg, ok := t["error"].(string) if !ok { @@ -609,6 +627,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { r.Thinking = sbThinking.String() r.Response = sbContent.String() + r.Logprobs = allLogprobs c.JSON(http.StatusOK, r) return @@ -1834,6 +1853,11 @@ func (s *Server) ChatHandler(c *gin.Context) { return } + if req.TopLogprobs < 0 || req.TopLogprobs > 20 { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"}) + return + } + name := model.ParseName(req.Model) if !name.IsValid() { c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) @@ -1859,6 +1883,11 @@ func (s *Server) ChatHandler(c *gin.Context) { return } + if req.TopLogprobs < 0 || req.TopLogprobs > 20 { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "top_logprobs must be between 0 and 20"}) + return + } + // expire the runner if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 { s.sched.expireRunner(m) @@ -2104,12 +2133,14 @@ func (s *Server) ChatHandler(c *gin.Context) { // sets up new context given parent context per request ctx, cancel := context.WithCancel(c.Request.Context()) err := r.Completion(ctx, llm.CompletionRequest{ - Prompt: prompt, - Images: images, - Format: currentFormat, - Options: opts, - Shift: req.Shift == nil || *req.Shift, - Truncate: truncate, + Prompt: prompt, + Images: images, + Format: currentFormat, + Options: opts, + Shift: req.Shift == nil || *req.Shift, + Truncate: truncate, + Logprobs: req.Logprobs, + TopLogprobs: req.TopLogprobs, }, func(r llm.CompletionResponse) { res := api.ChatResponse{ Model: req.Model, @@ -2122,7 +2153,9 @@ func (s *Server) ChatHandler(c *gin.Context) { EvalCount: r.EvalCount, EvalDuration: r.EvalDuration, }, + Logprobs: toAPILogprobs(r.Logprobs), } + if r.Done { res.DoneReason = r.DoneReason.String() res.TotalDuration = time.Since(checkpointStart) @@ -2251,6 +2284,7 @@ func (s *Server) ChatHandler(c *gin.Context) { if req.Stream != nil && !*req.Stream { var resp api.ChatResponse var toolCalls []api.ToolCall + var allLogprobs []api.Logprob var sbThinking strings.Builder var sbContent strings.Builder for rr := range ch { @@ -2262,6 +2296,10 @@ func (s *Server) ChatHandler(c *gin.Context) { if len(req.Tools) > 0 { toolCalls = append(toolCalls, t.Message.ToolCalls...) } + // Accumulate logprobs from all chunks for non-streaming response + if len(t.Logprobs) > 0 { + allLogprobs = append(allLogprobs, t.Logprobs...) 
+ } case gin.H: msg, ok := t["error"].(string) if !ok { @@ -2283,6 +2321,7 @@ func (s *Server) ChatHandler(c *gin.Context) { resp.Message.Content = sbContent.String() resp.Message.Thinking = sbThinking.String() + resp.Logprobs = allLogprobs if len(toolCalls) > 0 { resp.Message.ToolCalls = toolCalls diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index ed5922f2d2..a6be3bf308 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -1184,6 +1184,86 @@ func TestGenerate(t *testing.T) { }) } +func TestGenerateLogprobs(t *testing.T) { + t.Run("invalid top_logprobs negative", func(t *testing.T) { + gin.SetMode(gin.TestMode) + s := Server{} + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test", + Prompt: "Hello", + TopLogprobs: -1, + }) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"top_logprobs must be between 0 and 20"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + + t.Run("invalid top_logprobs too high", func(t *testing.T) { + gin.SetMode(gin.TestMode) + s := Server{} + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test", + Prompt: "Hello", + TopLogprobs: 21, + }) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"top_logprobs must be between 0 and 20"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) +} + +func TestChatLogprobs(t *testing.T) { + t.Run("invalid top_logprobs negative", func(t *testing.T) { + gin.SetMode(gin.TestMode) + s := Server{} + w := createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "test", + Messages: []api.Message{ + {Role: "user", Content: "Hello"}, + }, + TopLogprobs: -1, + }) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"top_logprobs must be between 0 and 20"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + + t.Run("invalid top_logprobs too high", func(t *testing.T) { + gin.SetMode(gin.TestMode) + s := Server{} + w := createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "test", + Messages: []api.Message{ + {Role: "user", Content: "Hello"}, + }, + TopLogprobs: 21, + }) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"top_logprobs must be between 0 and 20"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) +} + func TestChatWithPromptEndingInThinkTag(t *testing.T) { gin.SetMode(gin.TestMode)
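End-of-patch note, not part of the diff: a hedged sketch of the same feature through the OpenAI-compatible endpoint, to show where the ChoiceLogprobs conversion surfaces. The model name and the default local endpoint URL are assumptions; the response shape described in the comments follows the Choice/ChoiceLogprobs structs added in openai/openai.go rather than a captured server reply.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
)

func main() {
	// Request body mirroring ChatCompletionRequest: "logprobs" enables
	// per-token log probabilities, "top_logprobs" asks for up to N
	// alternatives at each position (validated server-side as 0-20).
	body, err := json.Marshal(map[string]any{
		"model": "llama3.2", // placeholder model name
		"messages": []map[string]string{
			{"role": "user", "content": "Say hello in one word"},
		},
		"logprobs":     true,
		"top_logprobs": 3,
	})
	if err != nil {
		log.Fatal(err)
	}

	// Assumed default local endpoint; adjust host/port for your setup.
	resp, err := http.Post("http://localhost:11434/v1/chat/completions",
		"application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	raw, _ := io.ReadAll(resp.Body)
	fmt.Println(string(raw))
	// choices[0].logprobs.content is a list of {token, logprob, top_logprobs}
	// entries, matching the ChoiceLogprobs struct added in openai/openai.go.
}

For the legacy completions path, the request instead carries an integer logprobs field, which FromCompleteRequest maps to Logprobs=true and TopLogprobs=N as shown in the openai/openai.go hunk above.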