mirror of
https://github.com/ollama/ollama.git
synced 2025-03-18 22:01:47 +01:00
121 lines
2.1 KiB
Go
121 lines
2.1 KiB
Go
package sample
|
|
|
|
import (
|
|
"cmp"
|
|
"math"
|
|
"slices"
|
|
|
|
pq "github.com/emirpasic/gods/v2/queues/priorityqueue"
|
|
)
|
|
|
|
// Transform is implemented by types that take a slice of logits and return a
// slice of logits.
type Transform interface {
	// Apply processes logits and returns the resulting logits.
	Apply(logits []float64) []float64
}
|
|
|
|
// softmax converts logits into probabilities and returns them in a newly
// allocated slice; logits itself is not modified. An empty input yields an
// empty result.
//
// The largest logit is subtracted from every element before exponentiating.
// This leaves the result mathematically unchanged but keeps math.Exp from
// overflowing to +Inf (and the normalization from producing NaN) when the
// logits are large.
//
// TODO(parthsareen): potentially cache softmax values
func softmax(logits []float64) []float64 {
	probs := make([]float64, len(logits))
	if len(logits) == 0 {
		// slices.Max panics on an empty slice
		return probs
	}

	shift := slices.Max(logits)
	if math.IsInf(shift, 0) {
		// an infinite maximum cannot be subtracted (Inf - Inf is NaN), so
		// exponentiate the raw values in that case
		shift = 0
	}

	var sum float64
	for i, v := range logits {
		probs[i] = math.Exp(v - shift)
		sum += probs[i]
	}

	for i := range probs {
		probs[i] /= sum
	}

	return probs
}
|
|
|
|
// Temperature is a Transform that rescales logits by a temperature value.
type Temperature float64

// Apply shifts logits so that the largest one becomes zero and then divides
// each by the temperature. logits is modified in place and returned.
//
// Temperatures below 1e-7 (including zero and negative values) are raised to
// 1e-7 so the division is always by a positive number. An empty slice is
// returned unchanged.
func (t Temperature) Apply(logits []float64) []float64 {
	if len(logits) == 0 {
		// slices.Max panics on an empty slice
		return logits
	}

	temp := math.Max(float64(t), 1e-7)

	// subtracting max logit to avoid under/overflow
	maxLogit := slices.Max(logits)
	for i := range logits {
		logits[i] = (logits[i] - maxLogit) / temp
	}

	return logits
}
|
|
|
|
type logitMap struct {
|
|
index int
|
|
logit float64
|
|
}
|
|
|
|
type TopK int
|
|
|
|
// TODO(parthsareen): avoid having to check all logits after this transform
|
|
func (k TopK) Apply(logits []float64) []float64 {
|
|
if int(k) >= len(logits) {
|
|
return logits
|
|
}
|
|
q := pq.NewWith(func(a, b logitMap) int {
|
|
return -cmp.Compare(a.logit, b.logit)
|
|
})
|
|
|
|
for i, logit := range logits {
|
|
q.Enqueue(logitMap{index: i, logit: logit})
|
|
}
|
|
|
|
validLogits := make(map[int]float64)
|
|
for range k {
|
|
logitMap, _ := q.Dequeue()
|
|
validLogits[logitMap.index] = logitMap.logit
|
|
}
|
|
|
|
for i := range logits {
|
|
if _, ok := validLogits[i]; !ok {
|
|
logits[i] = math.Inf(-1)
|
|
}
|
|
}
|
|
|
|
return logits
|
|
}
|
|
|
|
type TopP float64
|
|
|
|
func (p TopP) Apply(logits []float64) []float64 {
|
|
probs := softmax(logits)
|
|
indices := make([]int, len(probs))
|
|
for i := range indices {
|
|
indices[i] = i
|
|
}
|
|
|
|
// sort in descending order
|
|
slices.SortFunc(indices, func(i, j int) int {
|
|
return cmp.Compare(probs[j], probs[i])
|
|
})
|
|
|
|
var sum float64
|
|
for i, idx := range indices {
|
|
sum += probs[idx]
|
|
if sum > float64(p) {
|
|
for _, idx := range indices[i+1:] {
|
|
logits[idx] = math.Inf(-1)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
return logits
|
|
}
|
|
|
|
type MinP float64
|
|
|
|
func (p MinP) Apply(logits []float64) []float64 {
|
|
probs := softmax(logits)
|
|
threshold := slices.Max(probs) * float64(p)
|
|
|
|
for i, prob := range probs {
|
|
if prob < threshold {
|
|
logits[i] = math.Inf(-1)
|
|
}
|
|
}
|
|
|
|
return logits
|
|
}
|