mirror of
https://github.com/ollama/ollama.git
synced 2025-07-14 07:22:49 +02:00
wip
This commit is contained in:
@ -737,3 +737,14 @@ func SchemaToGrammar(schema []byte) []byte {
|
|||||||
}
|
}
|
||||||
return buf[:n]
|
return buf[:n]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetLogits returns the logits from the last decode operation.
|
||||||
|
// The returned slice has length equal to the vocabulary size.
|
||||||
|
func (c *Context) GetLogits() []float32 {
|
||||||
|
logits := unsafe.Pointer(C.llama_get_logits(c.c))
|
||||||
|
if logits == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the number of vocabulary tokens to determine array size
|
||||||
|
vocabSize := c.Model().NumVocab()
|
||||||
|
@ -1003,3 +1003,76 @@ func Execute(args []string) error {
|
|||||||
cancel()
|
cancel()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper function to get top K logits and convert to log probabilities
|
||||||
|
func getTopLogits(logits []float32, k int, model *llama.Model) []api.LogProbs {
|
||||||
|
if k <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert logits to probabilities using softmax
|
||||||
|
probs := softmax(logits)
|
||||||
|
|
||||||
|
// Create slice of index/probability pairs
|
||||||
|
pairs := make([]struct {
|
||||||
|
token int
|
||||||
|
prob float32
|
||||||
|
}, len(probs))
|
||||||
|
|
||||||
|
for i, p := range probs {
|
||||||
|
pairs[i] = struct {
|
||||||
|
token int
|
||||||
|
prob float32
|
||||||
|
}{i, p}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort by probability (descending)
|
||||||
|
sort.Slice(pairs, func(i, j int) bool {
|
||||||
|
return pairs[i].prob > pairs[j].prob
|
||||||
|
})
|
||||||
|
|
||||||
|
// Take top K
|
||||||
|
k = min(k, len(pairs))
|
||||||
|
result := make([]api.LogProbs, k)
|
||||||
|
|
||||||
|
for i := 0; i < k; i++ {
|
||||||
|
result[i] = api.LogProbs{
|
||||||
|
TopLogprobs: []api.TokenLogprob{
|
||||||
|
{
|
||||||
|
Token: model.TokenToPiece(pairs[i].token),
|
||||||
|
Logprob: float32(math.Log(float64(pairs[i].prob))),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to compute softmax
|
||||||
|
func softmax(logits []float32) []float32 {
|
||||||
|
probs := make([]float32, len(logits))
|
||||||
|
|
||||||
|
// Find max for numerical stability
|
||||||
|
max := float32(math.Inf(-1))
|
||||||
|
for _, l := range logits {
|
||||||
|
if l > max {
|
||||||
|
max = l
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute exp(x - max) and sum
|
||||||
|
sum := float32(0)
|
||||||
|
for i, l := range logits {
|
||||||
|
ex := float32(math.Exp(float64(l - max)))
|
||||||
|
probs[i] = ex
|
||||||
|
sum += ex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize
|
||||||
|
for i := range probs {
|
||||||
|
probs[i] /= sum
|
||||||
|
}
|
||||||
|
|
||||||
|
return probs
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user