print logprobs

Bruce MacDonald 2025-02-12 16:36:03 -08:00
parent 82658c3eec
commit 7d16ec8fe8
2 changed files with 95 additions and 3 deletions

@@ -50,7 +50,7 @@ import (
_ "github.com/ollama/ollama/llama/llama.cpp/common"
_ "github.com/ollama/ollama/llama/llama.cpp/examples/llava"
_ "github.com/ollama/ollama/llama/llama.cpp/src"
"github.com/ollama/ollama/ml/backend/ggml/ggml/src"
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
)
func BackendInit() {
@@ -220,6 +220,31 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
return embeddings
}
// GetLogits returns the logits from the last decode operation.
// The returned slice has length equal to the vocabulary size.
func (c *Context) GetLogits() []float32 {
logits := unsafe.Pointer(C.llama_get_logits(c.c))
if logits == nil {
return nil
}
// Get the number of vocabulary tokens to determine array size
vocabSize := c.Model().NumVocab()
return unsafe.Slice((*float32)(logits), vocabSize)
}
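For context, a minimal sketch of how the new accessor might be consumed; logitFor is a hypothetical helper, not part of this commit, and it assumes a *Context whose most recent decode succeeded:

// logitFor reads the raw logit for a single token id from the last decode.
// Hypothetical helper; relies only on GetLogits as defined above.
func logitFor(ctx *Context, tokenID int) (float32, error) {
	logits := ctx.GetLogits() // length equals the vocabulary size
	if logits == nil {
		return 0, fmt.Errorf("no logits available from the last decode")
	}
	if tokenID < 0 || tokenID >= len(logits) {
		return 0, fmt.Errorf("token id %d out of vocabulary range", tokenID)
	}
	return logits[tokenID], nil
}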
func (m *Model) Detokenize(tokens []int) (string, error) {
var text string
for _, token := range tokens {
piece := m.TokenToPiece(token)
if piece == "" {
return "", fmt.Errorf("failed to convert token %d to piece", token)
}
text += piece
}
return text, nil
}
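And a usage sketch for Detokenize, assuming a loaded *Model m; the token ids here are made up for illustration:

// Sketch: convert sampled token ids back into text.
tokens := []int{15339, 1917} // hypothetical ids
text, err := m.Detokenize(tokens)
if err != nil {
	log.Fatalf("detokenize: %v", err)
}
fmt.Println(text)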
type ModelParams struct {
NumGpuLayers int
MainGpu int

@@ -8,12 +8,14 @@ import (
"fmt"
"log"
"log/slog"
"math"
"net"
"net/http"
"os"
"path/filepath"
"regexp"
"runtime"
"sort"
"strconv"
"strings"
"sync"
@@ -83,6 +85,8 @@ type Sequence struct {
doneReason string
logits []float32
// Metrics
startProcessingTime time.Time
startGenerationTime time.Time
@@ -274,6 +278,9 @@ func (s *Server) allNil() bool {
}
func flushPending(seq *Sequence) bool {
if len(seq.pendingResponses) == 0 {
return true
}
joined := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = []string{}
@@ -287,8 +294,11 @@ func flushPending(seq *Sequence) bool {
joined = joined[:len(joined)-1]
}
if len(joined) == 0 {
return true
}
// Add logits if requested and available
wantLogits := true // TODO: wire this to a request flag
if wantLogits && seq.logits != nil {
// resp.Logits = seq.logits
seq.logits = nil
}
select {
@@ -350,6 +360,57 @@ func (s *Server) run(ctx context.Context) {
}
}
// TokenData represents probability information for a token
type TokenData struct {
TokenID int
Logit float32
Prob float32
LogProb float32
}
// getTokenProbabilities returns token probabilities for the most recent decode, sorted by logit in descending order
func (s *Server) getTokenProbabilities(seq *Sequence) []TokenData {
// Get the logits from the last decode
logits := s.lc.GetLogits()
if logits == nil {
return nil
}
// Keep a copy on the sequence so it can be attached to the response
seq.logits = make([]float32, len(logits))
copy(seq.logits, logits)
vocabSize := s.model.NumVocab()
probs := make([]TokenData, vocabSize)
// Initialize token data with logits
for i := 0; i < vocabSize; i++ {
probs[i] = TokenData{
TokenID: i,
Logit: logits[i],
}
}
// Sort tokens by logits in descending order
sort.Slice(probs, func(i, j int) bool {
return probs[i].Logit > probs[j].Logit
})
// Apply softmax
maxLogit := probs[0].Logit
var sum float32 = 0.0
for i := range probs {
p := float32(math.Exp(float64(probs[i].Logit - maxLogit)))
probs[i].Prob = p
sum += p
}
// Normalize probabilities and calculate log probs
for i := range probs {
prob := probs[i].Prob / sum
probs[i].Prob = prob
probs[i].LogProb = float32(math.Log(float64(prob)))
}
return probs
}
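Subtracting maxLogit before exponentiating is the standard numerically stable softmax: the e^(-maxLogit) factor cancels during normalization, so the resulting probabilities are unchanged, but math.Exp never sees a large positive argument that could overflow. A self-contained sketch of the same normalization over a toy logit vector (names are illustrative, not part of the commit):

package main

import (
	"fmt"
	"math"
)

// stableSoftmax mirrors the normalization in getTokenProbabilities:
// shift by the max logit, exponentiate, normalize, then take logs.
func stableSoftmax(logits []float64) (probs, logProbs []float64) {
	maxLogit := logits[0]
	for _, l := range logits {
		if l > maxLogit {
			maxLogit = l
		}
	}
	probs = make([]float64, len(logits))
	logProbs = make([]float64, len(logits))
	var sum float64
	for i, l := range logits {
		probs[i] = math.Exp(l - maxLogit)
		sum += probs[i]
	}
	for i := range probs {
		probs[i] /= sum
		logProbs[i] = math.Log(probs[i])
	}
	return probs, logProbs
}

func main() {
	probs, logProbs := stableSoftmax([]float64{5, 2, 0.5})
	fmt.Println(probs)    // sums to 1
	fmt.Println(logProbs) // each entry is ln(prob)
}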
// TODO (jmorganca): processBatch should be simplified, removing:
// * sampling
// * stop token checking
@@ -483,6 +544,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
seq.numPredicted++
// TODO: only do this when flag specified
probs := s.getTokenProbabilities(seq)
// Debug: log the ten highest-probability candidate tokens
for i := 0; i < 10 && i < len(probs); i++ {
slog.Debug("top 10 tokens", "token", probs[i].TokenID, "prob", probs[i].Prob, "logit", probs[i].Logit, "piece", s.model.TokenToPiece(probs[i].TokenID))
}
// if it's an end of sequence token, break
if s.model.TokenIsEog(token) {
// TODO (jmorganca): we should send this back