logprob: add bytes to logprobs (#13068)

Author: Parth Sareen
Date: 2025-11-13 13:49:25 -08:00
Committed by: GitHub
Parent: b48083f33f
Commit: c114987523
4 changed files with 312 additions and 0 deletions
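
The tests below call a stringToByteInts helper that is defined in one of this commit's other changed files. As a minimal sketch only (the real definition may differ), such a helper plausibly expands a token string into its raw UTF-8 byte values as ints, which is the shape the new Bytes field on logprob entries carries:

// stringToByteInts is a hedged sketch of the helper the tests call; the
// actual implementation lives elsewhere in this commit. It converts a
// token string into its UTF-8 byte values as ints.
func stringToByteInts(s string) []int {
	out := make([]int, len(s))
	for i := 0; i < len(s); i++ {
		out[i] = int(s[i]) // indexing a string yields bytes, not runes
	}
	return out
}

Returning bytes alongside the token text matters when a tokenizer splits a multi-byte character across tokens: the text field alone may not be valid UTF-8, but the byte values always are recoverable.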


@@ -1220,6 +1220,139 @@ func TestGenerateLogprobs(t *testing.T) {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("returns logprob bytes when requested", func(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := &mockRunner{}
expectedPrimary := llm.TokenLogprob{
Token: "Hi",
Logprob: -0.01,
}
expectedAlternatives := []llm.TokenLogprob{
{
Token: "Hello",
Logprob: -0.25,
},
{
Token: "Hey",
Logprob: -0.5,
},
}
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
fn(llm.CompletionResponse{
Content: "Hi",
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
Logprobs: []llm.Logprob{
{
TokenLogprob: expectedPrimary,
TopLogprobs: expectedAlternatives,
},
},
})
return nil
}
s := &Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
req.successCh <- &runnerRef{llama: mock}
return false
},
},
}
go s.sched.Run(t.Context())
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
stream := false
if w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-logprob-bytes",
Files: map[string]string{"file.gguf": digest},
Template: `{{ .Prompt }}`,
Stream: &stream,
}); w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-logprob-bytes",
Prompt: "Hi",
Stream: &stream,
Logprobs: true,
TopLogprobs: len(expectedAlternatives),
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
var resp api.GenerateResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if len(resp.Logprobs) != 1 {
t.Fatalf("expected 1 logprob entry, got %d", len(resp.Logprobs))
}
expectedPrimaryBytes := stringToByteInts(expectedPrimary.Token)
expectedAlternativesBytes := make([][]int, len(expectedAlternatives))
for i, alternative := range expectedAlternatives {
expectedAlternativesBytes[i] = stringToByteInts(alternative.Token)
}
if diff := cmp.Diff(expectedPrimaryBytes, resp.Logprobs[0].Bytes); diff != "" {
t.Fatalf("primary token bytes mismatch (-want +got):\n%s", diff)
}
if len(resp.Logprobs[0].TopLogprobs) != len(expectedAlternatives) {
t.Fatalf("expected %d top logprobs, got %d", len(expectedAlternatives), len(resp.Logprobs[0].TopLogprobs))
}
for i, top := range resp.Logprobs[0].TopLogprobs {
if diff := cmp.Diff(expectedAlternativesBytes[i], top.Bytes); diff != "" {
t.Fatalf("top logprob[%d] bytes mismatch (-want +got):\n%s", i, diff)
}
}
})
}
func TestChatLogprobs(t *testing.T) {
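
The chat hunk below mirrors the generate test against the ChatHandler. As a usage aside, a hedged sketch of how a client might reassemble token text from the returned byte values, i.e. the inverse of the conversion asserted above (bytesToToken is a hypothetical name, not part of this diff):

// bytesToToken is illustrative only: it rebuilds a token's text from the
// byte values a logprob entry's Bytes field carries.
func bytesToToken(byteInts []int) string {
	b := make([]byte, len(byteInts))
	for i, v := range byteInts {
		b[i] = byte(v)
	}
	return string(b)
}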
@@ -1262,6 +1395,142 @@ func TestChatLogprobs(t *testing.T) {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("returns logprob bytes when requested", func(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := &mockRunner{}
expectedPrimary := llm.TokenLogprob{
Token: "Hi",
Logprob: -0.02,
}
expectedAlternatives := []llm.TokenLogprob{
{
Token: "Hello",
Logprob: -0.3,
},
{
Token: "Hey",
Logprob: -0.45,
},
}
expectedPrimaryBytes := stringToByteInts(expectedPrimary.Token)
expectedAlternativesBytes := make([][]int, len(expectedAlternatives))
for i, alternative := range expectedAlternatives {
expectedAlternativesBytes[i] = stringToByteInts(alternative.Token)
}
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
fn(llm.CompletionResponse{
Content: "Hi",
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
Logprobs: []llm.Logprob{
{
TokenLogprob: expectedPrimary,
TopLogprobs: expectedAlternatives,
},
},
})
return nil
}
s := &Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
req.successCh <- &runnerRef{llama: mock}
return false
},
},
}
go s.sched.Run(t.Context())
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
stream := false
if w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-chat-logprob-bytes",
Files: map[string]string{"file.gguf": digest},
Template: `{{- range .Messages }}{{ .Role }}: {{ .Content }}
{{ end }}`,
Stream: &stream,
}); w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-chat-logprob-bytes",
Messages: []api.Message{
{Role: "user", Content: "Say hi"},
},
Stream: &stream,
Logprobs: true,
TopLogprobs: len(expectedAlternatives),
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
var resp api.ChatResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if len(resp.Logprobs) != 1 {
t.Fatalf("expected 1 logprob entry, got %d", len(resp.Logprobs))
}
if diff := cmp.Diff(expectedPrimaryBytes, resp.Logprobs[0].Bytes); diff != "" {
t.Fatalf("primary token bytes mismatch (-want +got):\n%s", diff)
}
if len(resp.Logprobs[0].TopLogprobs) != len(expectedAlternatives) {
t.Fatalf("expected %d top logprobs, got %d", len(expectedAlternatives), len(resp.Logprobs[0].TopLogprobs))
}
for i, top := range resp.Logprobs[0].TopLogprobs {
if diff := cmp.Diff(expectedAlternativesBytes[i], top.Bytes); diff != "" {
t.Fatalf("top logprob[%d] bytes mismatch (-want +got):\n%s", i, diff)
}
}
})
}
func TestChatWithPromptEndingInThinkTag(t *testing.T) {