This commit is contained in:
ParthSareen 2024-12-13 14:03:10 -08:00
parent 844899440a
commit d7e7e6a01e
2 changed files with 84 additions and 0 deletions

View File

@ -737,3 +737,14 @@ func SchemaToGrammar(schema []byte) []byte {
}
return buf[:n]
}
// GetLogits returns the logits from the last decode operation.
// The returned slice has length equal to the vocabulary size.
func (c *Context) GetLogits() []float32 {
logits := unsafe.Pointer(C.llama_get_logits(c.c))
if logits == nil {
return nil
}
// Get the number of vocabulary tokens to determine array size
vocabSize := c.Model().NumVocab()

View File

@ -1003,3 +1003,76 @@ func Execute(args []string) error {
cancel()
return nil
}
// getTopLogits returns the k highest-scoring tokens together with their
// natural-log probabilities under a softmax over logits.
//
// The index of each element in logits is used as the token id handed to
// model.TokenToPiece. The result holds one api.LogProbs per selected token,
// ordered from most to least likely, each carrying a single TopLogprobs entry.
// It returns nil when k <= 0, and at most len(logits) entries otherwise.
//
// Log-probabilities are computed directly as
//
//	logit - max - log(sum(exp(logit_j - max)))
//
// rather than by taking the log of an already-normalised float32 probability.
// A probability that underflows to zero would otherwise turn into -Inf even
// though its true log-probability is finite.
// NOTE(review): non-finite values are presumably also unencodable if these
// results are serialised as JSON - confirm against the api package.
// Logits that are themselves -Inf still yield -Inf.
func getTopLogits(logits []float32, k int, model *llama.Model) []api.LogProbs {
	if k <= 0 {
		return nil
	}

	// Subtracting the largest logit before exponentiating keeps exp() from
	// overflowing.
	maxLogit := float32(math.Inf(-1))
	for _, l := range logits {
		if l > maxLogit {
			maxLogit = l
		}
	}

	// Accumulate the normaliser in float64 so that many tiny terms are not
	// rounded away.
	var sumExp float64
	for _, l := range logits {
		sumExp += math.Exp(float64(l - maxLogit))
	}
	logSumExp := math.Log(sumExp)

	type candidate struct {
		token int
		logit float32
	}
	ranked := make([]candidate, len(logits))
	for i, l := range logits {
		ranked[i] = candidate{token: i, logit: l}
	}

	// Softmax is monotonic, so ordering by raw logit is ordering by probability.
	sort.Slice(ranked, func(i, j int) bool {
		return ranked[i].logit > ranked[j].logit
	})

	if k > len(ranked) {
		k = len(ranked)
	}
	result := make([]api.LogProbs, k)
	for i := range result {
		c := ranked[i]
		result[i] = api.LogProbs{
			TopLogprobs: []api.TokenLogprob{
				{
					Token:   model.TokenToPiece(c.token),
					Logprob: float32(float64(c.logit-maxLogit) - logSumExp),
				},
			},
		}
	}
	return result
}
// softmax converts raw logits into a probability distribution.
//
// The returned slice has the same length as logits; element i is
// exp(logits[i]) / sum_j exp(logits[j]). An empty input yields an empty,
// non-nil slice.
//
// The largest logit is subtracted before exponentiating so exp() cannot
// overflow, and the normalising sum is accumulated in float64: with a large
// vocabulary a float32 accumulator silently drops the many small terms
// (1 + 1e-8 == 1 in float32), leaving the result un-normalised.
//
// If the largest logit is +Inf, or every logit is -Inf, the subtraction
// produces NaN and the output contains NaN; callers that mask tokens with -Inf
// must leave at least one finite logit.
func softmax(logits []float32) []float32 {
	probs := make([]float32, len(logits))
	if len(logits) == 0 {
		return probs
	}

	// Find the peak logit for numerical stability.
	peak := float32(math.Inf(-1))
	for _, l := range logits {
		if l > peak {
			peak = l
		}
	}

	// Store exp(x - peak) and accumulate the normaliser at full precision.
	var total float64
	for i, l := range logits {
		e := math.Exp(float64(l - peak))
		probs[i] = float32(e)
		total += e
	}

	// Normalise so the entries sum to 1.
	for i := range probs {
		probs[i] = float32(float64(probs[i]) / total)
	}
	return probs
}