package sample import ( "cmp" "math" "slices" pq "github.com/emirpasic/gods/v2/queues/priorityqueue" ) type Transform interface { Apply([]float64) []float64 } // TODO(parthsareen): potentially cache softmax values func softmax(logits []float64) []float64 { var sum float64 probs := make([]float64, len(logits)) for i, v := range logits { probs[i] = math.Exp(v) sum += probs[i] } for i := range probs { probs[i] /= sum } return probs } type Temperature float64 func (t Temperature) Apply(logits []float64) []float64 { temp := math.Max(float64(t), 1e-7) // subtracting max logit to avoid under/overflow maxLogit := slices.Max(logits) for i := range logits { logits[i] = (logits[i] - maxLogit) / temp } return logits } type logitMap struct { index int logit float64 } type TopK int // TODO(parthsareen): avoid having to check all logits after this transform func (k TopK) Apply(logits []float64) []float64 { if int(k) >= len(logits) { return logits } q := pq.NewWith(func(a, b logitMap) int { return -cmp.Compare(a.logit, b.logit) }) for i, logit := range logits { q.Enqueue(logitMap{index: i, logit: logit}) } validLogits := make(map[int]float64) for range k { logitMap, _ := q.Dequeue() validLogits[logitMap.index] = logitMap.logit } for i := range logits { if _, ok := validLogits[i]; !ok { logits[i] = math.Inf(-1) } } return logits } type TopP float64 func (p TopP) Apply(logits []float64) []float64 { probs := softmax(logits) indices := make([]int, len(probs)) for i := range indices { indices[i] = i } // sort in descending order slices.SortFunc(indices, func(i, j int) int { return cmp.Compare(probs[j], probs[i]) }) var sum float64 for i, idx := range indices { sum += probs[idx] if sum > float64(p) { for _, idx := range indices[i+1:] { logits[idx] = math.Inf(-1) } break } } return logits } type MinP float64 func (p MinP) Apply(logits []float64) []float64 { probs := softmax(logits) threshold := slices.Max(probs) * float64(p) for i, prob := range probs { if prob < threshold { logits[i] = math.Inf(-1) } } return logits }