mirror of
https://github.com/ollama/ollama.git
synced 2025-12-08 17:41:23 +01:00
logprob: add bytes to logprobs (#13068)
This commit is contained in:
@@ -14,6 +14,23 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func assertBytesMatchToken(t *testing.T, label, token string, ints []int) {
|
||||
t.Helper()
|
||||
|
||||
raw := []byte(token)
|
||||
if len(ints) != len(raw) {
|
||||
t.Errorf("%s expected %d bytes for token %q, got %d (%v)", label, len(raw), token, len(ints), ints)
|
||||
return
|
||||
}
|
||||
|
||||
for i, b := range raw {
|
||||
if ints[i] != int(b) {
|
||||
t.Errorf("%s byte[%d] mismatch for token %q: got %d want %d", label, i, token, ints[i], int(b))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIGenerate(t *testing.T) {
|
||||
initialTimeout := 60 * time.Second
|
||||
streamTimeout := 30 * time.Second
|
||||
@@ -466,6 +483,7 @@ func TestAPIGenerateLogprobs(t *testing.T) {
|
||||
if lp.Logprob > 0 {
|
||||
t.Errorf("logprob[%d] has positive logprob %f (should be <= 0)", i, lp.Logprob)
|
||||
}
|
||||
assertBytesMatchToken(t, fmt.Sprintf("generate logprob[%d]", i), lp.Token, lp.Bytes)
|
||||
|
||||
// Check top_logprobs if requested
|
||||
if test.topLogprobs > 0 {
|
||||
@@ -482,6 +500,9 @@ func TestAPIGenerateLogprobs(t *testing.T) {
|
||||
t.Errorf("logprob[%d].top_logprobs not sorted: %f < %f", i, lp.TopLogprobs[j-1].Logprob, lp.TopLogprobs[j].Logprob)
|
||||
}
|
||||
}
|
||||
for j, top := range lp.TopLogprobs {
|
||||
assertBytesMatchToken(t, fmt.Sprintf("generate logprob[%d].top[%d]", i, j), top.Token, top.Bytes)
|
||||
}
|
||||
} else if len(lp.TopLogprobs) > 0 {
|
||||
t.Errorf("logprob[%d] has top_logprobs but none were requested", i)
|
||||
}
|
||||
@@ -544,11 +565,15 @@ func TestAPIChatLogprobs(t *testing.T) {
|
||||
if lp.Logprob > 0 {
|
||||
t.Errorf("logprob[%d] has positive logprob %f", i, lp.Logprob)
|
||||
}
|
||||
assertBytesMatchToken(t, fmt.Sprintf("chat logprob[%d]", i), lp.Token, lp.Bytes)
|
||||
if len(lp.TopLogprobs) == 0 {
|
||||
t.Errorf("logprob[%d] expected top_logprobs but got none", i)
|
||||
}
|
||||
if len(lp.TopLogprobs) > 3 {
|
||||
t.Errorf("logprob[%d] has %d top_logprobs, expected max 3", i, len(lp.TopLogprobs))
|
||||
}
|
||||
for j, top := range lp.TopLogprobs {
|
||||
assertBytesMatchToken(t, fmt.Sprintf("chat logprob[%d].top[%d]", i, j), top.Token, top.Bytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user