mirror of
https://github.com/ollama/ollama.git
synced 2025-04-11 05:09:45 +02:00
wip
This commit is contained in:
parent
844899440a
commit
d7e7e6a01e
@ -737,3 +737,14 @@ func SchemaToGrammar(schema []byte) []byte {
|
||||
}
|
||||
return buf[:n]
|
||||
}
|
||||
|
||||
// GetLogits returns the logits from the last decode operation.
|
||||
// The returned slice has length equal to the vocabulary size.
|
||||
func (c *Context) GetLogits() []float32 {
|
||||
logits := unsafe.Pointer(C.llama_get_logits(c.c))
|
||||
if logits == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the number of vocabulary tokens to determine array size
|
||||
vocabSize := c.Model().NumVocab()
|
||||
|
@ -1003,3 +1003,76 @@ func Execute(args []string) error {
|
||||
cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper function to get top K logits and convert to log probabilities
|
||||
func getTopLogits(logits []float32, k int, model *llama.Model) []api.LogProbs {
|
||||
if k <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert logits to probabilities using softmax
|
||||
probs := softmax(logits)
|
||||
|
||||
// Create slice of index/probability pairs
|
||||
pairs := make([]struct {
|
||||
token int
|
||||
prob float32
|
||||
}, len(probs))
|
||||
|
||||
for i, p := range probs {
|
||||
pairs[i] = struct {
|
||||
token int
|
||||
prob float32
|
||||
}{i, p}
|
||||
}
|
||||
|
||||
// Sort by probability (descending)
|
||||
sort.Slice(pairs, func(i, j int) bool {
|
||||
return pairs[i].prob > pairs[j].prob
|
||||
})
|
||||
|
||||
// Take top K
|
||||
k = min(k, len(pairs))
|
||||
result := make([]api.LogProbs, k)
|
||||
|
||||
for i := 0; i < k; i++ {
|
||||
result[i] = api.LogProbs{
|
||||
TopLogprobs: []api.TokenLogprob{
|
||||
{
|
||||
Token: model.TokenToPiece(pairs[i].token),
|
||||
Logprob: float32(math.Log(float64(pairs[i].prob))),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Helper function to compute softmax
|
||||
func softmax(logits []float32) []float32 {
|
||||
probs := make([]float32, len(logits))
|
||||
|
||||
// Find max for numerical stability
|
||||
max := float32(math.Inf(-1))
|
||||
for _, l := range logits {
|
||||
if l > max {
|
||||
max = l
|
||||
}
|
||||
}
|
||||
|
||||
// Compute exp(x - max) and sum
|
||||
sum := float32(0)
|
||||
for i, l := range logits {
|
||||
ex := float32(math.Exp(float64(l - max)))
|
||||
probs[i] = ex
|
||||
sum += ex
|
||||
}
|
||||
|
||||
// Normalize
|
||||
for i := range probs {
|
||||
probs[i] /= sum
|
||||
}
|
||||
|
||||
return probs
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user