mirror of
https://github.com/ollama/ollama.git
synced 2025-03-29 11:11:47 +01:00
sample: add sampling package for new engine (#8410)
This commit is contained in:
parent
314573bfe8
commit
0b7e1676eb
@ -65,8 +65,8 @@ type Sequence struct {
|
||||
// number of tokens to predict
|
||||
numPredict int
|
||||
|
||||
// set of samplers to run on generated logits
|
||||
samplers []sample.Sampler
|
||||
// sampler with transforms to run on generated logits
|
||||
sampler sample.Sampler
|
||||
|
||||
// channel to send back the embedding if embedding only
|
||||
embedding chan []float32
|
||||
@ -93,7 +93,7 @@ type NewSequenceParams struct {
|
||||
numPredict int
|
||||
stop []string
|
||||
numKeep int32
|
||||
samplers []sample.Sampler
|
||||
sampler sample.Sampler
|
||||
embedding bool
|
||||
}
|
||||
|
||||
@ -136,7 +136,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
||||
responses: make(chan string, 100),
|
||||
quit: make(chan bool, 1),
|
||||
embedding: make(chan []float32, 1),
|
||||
samplers: params.samplers,
|
||||
sampler: params.sampler,
|
||||
embeddingOnly: params.embedding,
|
||||
stop: params.stop,
|
||||
numKeep: params.numKeep,
|
||||
@ -393,13 +393,7 @@ func (s *Server) processBatch() error {
|
||||
return fmt.Errorf("failed to decode batch: %w", err)
|
||||
}
|
||||
|
||||
f32s := modelOutput.Floats()
|
||||
|
||||
// TODO(jessegross): This will no longer be necessary once the sampling interface takes f32s
|
||||
logits := make([]float64, len(f32s))
|
||||
for i, f32 := range f32s {
|
||||
logits[i] = float64(f32)
|
||||
}
|
||||
logits := modelOutput.Floats()
|
||||
|
||||
for i, seq := range s.seqs {
|
||||
if seq == nil {
|
||||
@ -433,14 +427,12 @@ func (s *Server) processBatch() error {
|
||||
}
|
||||
|
||||
// sample a token
|
||||
vocabSize := len(f32s) / len(options.Outputs)
|
||||
tokens, err := sample.Sample(logits[seq.iBatch*vocabSize:(seq.iBatch+1)*vocabSize], seq.samplers...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
vocabSize := len(logits) / len(options.Outputs)
|
||||
|
||||
// TODO(jessegross): Sampler will output a single int32 in the future
|
||||
token := int32(tokens[0])
|
||||
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sample token: %w", err)
|
||||
}
|
||||
|
||||
// if it's an end of sequence token, break
|
||||
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
||||
@ -565,27 +557,6 @@ type CompletionResponse struct {
|
||||
Timings Timings `json:"timings"`
|
||||
}
|
||||
|
||||
func getSamplers(_ CompletionRequest) []sample.Sampler {
|
||||
// TODO(jessegross): Waiting for sampling code
|
||||
|
||||
/*samplingParams.TopK = req.TopK
|
||||
samplingParams.TopP = req.TopP
|
||||
samplingParams.MinP = req.MinP
|
||||
samplingParams.TypicalP = req.TypicalP
|
||||
samplingParams.Temp = req.Temperature
|
||||
samplingParams.RepeatLastN = req.RepeatLastN
|
||||
samplingParams.PenaltyRepeat = req.RepeatPenalty
|
||||
samplingParams.PenaltyFreq = req.FrequencyPenalty
|
||||
samplingParams.PenaltyPresent = req.PresencePenalty
|
||||
samplingParams.Mirostat = req.Mirostat
|
||||
samplingParams.MirostatTau = req.MirostatTau
|
||||
samplingParams.MirostatEta = req.MirostatEta
|
||||
samplingParams.Seed = uint32(req.Seed)
|
||||
samplingParams.Grammar = req.Grammar*/
|
||||
|
||||
return []sample.Sampler{sample.Greedy()}
|
||||
}
|
||||
|
||||
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
var req CompletionRequest
|
||||
req.Options = Options(api.DefaultOptions())
|
||||
@ -604,11 +575,23 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
sampler, err := sample.NewSampler(
|
||||
req.Temperature,
|
||||
req.TopK,
|
||||
req.TopP,
|
||||
req.MinP,
|
||||
req.Seed,
|
||||
)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to create sampler: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||
numPredict: req.NumPredict,
|
||||
stop: req.Stop,
|
||||
numKeep: int32(req.NumKeep),
|
||||
samplers: getSamplers(req),
|
||||
sampler: sampler,
|
||||
embedding: false,
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -1,13 +0,0 @@
|
||||
package sample
|
||||
|
||||
import "gonum.org/v1/gonum/floats"
|
||||
|
||||
type greedy struct{}
|
||||
|
||||
func Greedy() Sampler {
|
||||
return greedy{}
|
||||
}
|
||||
|
||||
func (s greedy) Sample(t []float64) ([]float64, error) {
|
||||
return []float64{float64(floats.MaxIdx(t))}, nil
|
||||
}
|
@ -1,74 +0,0 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"slices"
|
||||
|
||||
"gonum.org/v1/gonum/floats"
|
||||
"gonum.org/v1/gonum/stat/sampleuv"
|
||||
)
|
||||
|
||||
type Sampler interface {
|
||||
Sample([]float64) ([]float64, error)
|
||||
}
|
||||
|
||||
type Temperature float64
|
||||
|
||||
func (s Temperature) Sample(t []float64) ([]float64, error) {
|
||||
floats.Div(t, slices.Repeat([]float64{float64(s)}, len(t)))
|
||||
return t, nil
|
||||
}
|
||||
|
||||
type softmax struct{}
|
||||
|
||||
func Softmax() Sampler {
|
||||
return softmax{}
|
||||
}
|
||||
|
||||
func (softmax) Sample(t []float64) ([]float64, error) {
|
||||
return t, nil
|
||||
}
|
||||
|
||||
type TopK int
|
||||
|
||||
func (s TopK) Sample(t []float64) ([]float64, error) {
|
||||
return t, nil
|
||||
}
|
||||
|
||||
type TopP float32
|
||||
|
||||
func (s TopP) Sample(t []float64) ([]float64, error) {
|
||||
return t, nil
|
||||
}
|
||||
|
||||
type MinP float32
|
||||
|
||||
func (s MinP) Sample(t []float64) ([]float64, error) {
|
||||
return t, nil
|
||||
}
|
||||
|
||||
type weighed struct{}
|
||||
|
||||
func Weighed() Sampler {
|
||||
return weighed{}
|
||||
}
|
||||
|
||||
func (s weighed) Sample(t []float64) ([]float64, error) {
|
||||
w := sampleuv.NewWeighted(t, nil)
|
||||
if v, ok := w.Take(); ok {
|
||||
return []float64{float64(v)}, nil
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func Sample(floats []float64, samplers ...Sampler) ([]float64, error) {
|
||||
var err error
|
||||
for _, sampler := range samplers {
|
||||
floats, err = sampler.Sample(floats)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return floats, nil
|
||||
}
|
139
sample/samplers.go
Normal file
139
sample/samplers.go
Normal file
@ -0,0 +1,139 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"golang.org/x/exp/rand"
|
||||
"gonum.org/v1/gonum/stat/sampleuv"
|
||||
)
|
||||
|
||||
type Sampler interface {
|
||||
Sample([]float32) (int32, error)
|
||||
}
|
||||
|
||||
type weighted struct {
|
||||
src rand.Source
|
||||
transforms []Transform
|
||||
}
|
||||
|
||||
// TODO(parthsareen): remove uv sample dependency https://github.com/ollama/ollama/issues/9279
|
||||
func Weighted(seed *uint64, transforms ...Transform) Sampler {
|
||||
var src rand.Source
|
||||
if seed != nil {
|
||||
src = rand.NewSource(*seed)
|
||||
}
|
||||
return weighted{src: src, transforms: transforms}
|
||||
}
|
||||
|
||||
func (s weighted) Sample(logits []float32) (int32, error) {
|
||||
logits64 := make([]float64, len(logits))
|
||||
for i, v := range logits {
|
||||
logits64[i] = float64(v)
|
||||
}
|
||||
|
||||
for _, t := range s.transforms {
|
||||
logits64 = t.Apply(logits64)
|
||||
}
|
||||
|
||||
logitsCopy := make([]float64, 0, len(logits))
|
||||
indices := make([]int, 0, len(logits))
|
||||
for i, logit := range logits64 {
|
||||
if !math.IsInf(logit, -1) {
|
||||
logitsCopy = append(logitsCopy, logit)
|
||||
indices = append(indices, i)
|
||||
}
|
||||
}
|
||||
|
||||
if len(logitsCopy) == 0 {
|
||||
return -1, errors.New("no valid logits found for weighed sampling")
|
||||
}
|
||||
|
||||
probs := softmax(logitsCopy)
|
||||
w := sampleuv.NewWeighted(probs, s.src)
|
||||
if idx, ok := w.Take(); ok {
|
||||
return int32(indices[idx]), nil
|
||||
}
|
||||
return -1, errors.New("weighed sampler failed, no valid token found")
|
||||
}
|
||||
|
||||
type greedy struct {
|
||||
transforms []Transform
|
||||
}
|
||||
|
||||
func Greedy(transforms ...Transform) Sampler {
|
||||
return greedy{transforms: transforms}
|
||||
}
|
||||
|
||||
func (s greedy) Sample(logits []float32) (int32, error) {
|
||||
logits64 := make([]float64, len(logits))
|
||||
for i, v := range logits {
|
||||
logits64[i] = float64(v)
|
||||
}
|
||||
|
||||
for _, t := range s.transforms {
|
||||
logits64 = t.Apply(logits64)
|
||||
}
|
||||
|
||||
var maxIdx int
|
||||
var maxLogit float64
|
||||
for i, logit := range logits64 {
|
||||
if logit > maxLogit {
|
||||
maxLogit = logit
|
||||
maxIdx = i
|
||||
}
|
||||
}
|
||||
|
||||
if maxLogit == math.Inf(-1) {
|
||||
return -1, errors.New("no valid logits found for greedy sampling")
|
||||
}
|
||||
|
||||
return int32(maxIdx), nil
|
||||
}
|
||||
|
||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) (Sampler, error) {
|
||||
transforms := []Transform{}
|
||||
if temperature < 0 || temperature > 2 {
|
||||
return nil, errors.New("temperature must be between 0 and 2")
|
||||
}
|
||||
|
||||
if temperature != 0 {
|
||||
transforms = append(transforms, Temperature(temperature))
|
||||
}
|
||||
|
||||
if topK != 0 {
|
||||
if topK <= 0 {
|
||||
return nil, errors.New("topK must be greater than 0")
|
||||
}
|
||||
transforms = append(transforms, TopK(topK))
|
||||
}
|
||||
|
||||
if topP != 0 {
|
||||
if topP < 0 || topP >= 1 {
|
||||
return nil, errors.New("topP must be between 0 and 1")
|
||||
}
|
||||
transforms = append(transforms, TopP(topP))
|
||||
}
|
||||
|
||||
if minP != 0 {
|
||||
if minP < 0 || minP >= 1 {
|
||||
return nil, errors.New("minP must be between 0 and 1")
|
||||
}
|
||||
transforms = append(transforms, MinP(minP))
|
||||
}
|
||||
|
||||
if len(transforms) == 0 {
|
||||
return nil, errors.New("at least one transform is required")
|
||||
}
|
||||
|
||||
if temperature == 0 {
|
||||
return Greedy(transforms...), nil
|
||||
}
|
||||
|
||||
if seed != 0 {
|
||||
seed64 := uint64(seed)
|
||||
return Weighted(&seed64, transforms...), nil
|
||||
}
|
||||
return Weighted(nil, transforms...), nil
|
||||
}
|
238
sample/samplers_test.go
Normal file
238
sample/samplers_test.go
Normal file
@ -0,0 +1,238 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestWeighted(t *testing.T) {
|
||||
got, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
want := int32(1)
|
||||
if want != got {
|
||||
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||
}
|
||||
|
||||
got, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))})
|
||||
if err == nil {
|
||||
t.Error("expected error for no valid tokens, got index", got)
|
||||
}
|
||||
|
||||
seed := uint64(42)
|
||||
got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
// With seed 42, we expect a consistent sample
|
||||
want = int32(3) // This will be deterministic due to the seed
|
||||
if want != got {
|
||||
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
type testTransform struct {
|
||||
id int
|
||||
callOrder *[]int
|
||||
}
|
||||
|
||||
func (ts *testTransform) Apply(logits []float64) []float64 {
|
||||
if ts.callOrder != nil {
|
||||
*ts.callOrder = append(*ts.callOrder, ts.id)
|
||||
}
|
||||
return logits
|
||||
}
|
||||
|
||||
func TestSample(t *testing.T) {
|
||||
input := []float32{1, 2, 3, 4}
|
||||
|
||||
var callOrder []int
|
||||
mock1 := &testTransform{
|
||||
id: 1,
|
||||
callOrder: &callOrder,
|
||||
}
|
||||
mock2 := &testTransform{
|
||||
id: 2,
|
||||
callOrder: &callOrder,
|
||||
}
|
||||
mock3 := &testTransform{
|
||||
id: 3,
|
||||
callOrder: &callOrder,
|
||||
}
|
||||
|
||||
got, err := Greedy(mock1, mock2, mock3).Sample(input)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
want := int32(3) // Greedy sampler should pick highest logit
|
||||
if want != got {
|
||||
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||
}
|
||||
wantOrder := []int{1, 2, 3}
|
||||
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
|
||||
t.Errorf("call order mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
callOrder = nil
|
||||
|
||||
_, err = Weighted(nil, mock1, mock2, mock3).Sample(input)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
wantOrder = []int{1, 2, 3}
|
||||
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
|
||||
t.Errorf("call order mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSampler(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
temperature float32
|
||||
topK int
|
||||
topP float32
|
||||
minP float32
|
||||
seed int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no transforms",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "temperature",
|
||||
temperature: 0.5,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid temperature negative",
|
||||
temperature: -1,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid temperature too high",
|
||||
temperature: 2.1,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "top k",
|
||||
topK: 10,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid top k negative",
|
||||
topK: -1,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "top p",
|
||||
topP: 0.9,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid top p negative",
|
||||
topP: -0.1,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid top p one",
|
||||
topP: 1.0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "min p",
|
||||
minP: 0.2,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid min p negative",
|
||||
minP: -0.1,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid min p one",
|
||||
minP: 1.0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "seed",
|
||||
seed: 42,
|
||||
wantErr: true, // seed alone is not valid without other transforms
|
||||
},
|
||||
{
|
||||
name: "default values",
|
||||
temperature: 0.8,
|
||||
topK: 40,
|
||||
topP: 0.9,
|
||||
minP: 0.0,
|
||||
seed: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "all zeroes",
|
||||
temperature: 0.0,
|
||||
topK: 0,
|
||||
topP: 0.0,
|
||||
minP: 0.0,
|
||||
seed: 0,
|
||||
wantErr: true, // all zeroes means no transforms
|
||||
},
|
||||
{
|
||||
name: "all transforms",
|
||||
temperature: 0.8,
|
||||
topK: 50,
|
||||
topP: 0.95,
|
||||
minP: 0.1,
|
||||
seed: 42,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSample(b *testing.B) {
|
||||
transforms := []Transform{
|
||||
Temperature(0.5),
|
||||
TopK(10),
|
||||
TopP(0.9),
|
||||
MinP(0.2),
|
||||
}
|
||||
|
||||
samplers := map[string]Sampler{
|
||||
"Greedy": Greedy(transforms...),
|
||||
"Weighted": Weighted(nil, transforms...),
|
||||
}
|
||||
|
||||
logits := make([]float32, 1<<16)
|
||||
for i := range logits {
|
||||
logits[i] = rand.Float32()
|
||||
}
|
||||
|
||||
for name, s := range samplers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
if _, err := s.Sample(logits); err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
120
sample/transforms.go
Normal file
120
sample/transforms.go
Normal file
@ -0,0 +1,120 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
pq "github.com/emirpasic/gods/v2/queues/priorityqueue"
|
||||
)
|
||||
|
||||
type Transform interface {
|
||||
Apply([]float64) []float64
|
||||
}
|
||||
|
||||
// TODO(parthsareen): potentially cache softmax values
|
||||
func softmax(logits []float64) []float64 {
|
||||
var sum float64
|
||||
probs := make([]float64, len(logits))
|
||||
for i, v := range logits {
|
||||
probs[i] = math.Exp(v)
|
||||
sum += probs[i]
|
||||
}
|
||||
|
||||
for i := range probs {
|
||||
probs[i] /= sum
|
||||
}
|
||||
|
||||
return probs
|
||||
}
|
||||
|
||||
type Temperature float64
|
||||
|
||||
func (t Temperature) Apply(logits []float64) []float64 {
|
||||
temp := math.Max(float64(t), 1e-7)
|
||||
|
||||
// subtracting max logit to avoid under/overflow
|
||||
maxLogit := slices.Max(logits)
|
||||
for i := range logits {
|
||||
logits[i] = (logits[i] - maxLogit) / temp
|
||||
}
|
||||
|
||||
return logits
|
||||
}
|
||||
|
||||
type logitMap struct {
|
||||
index int
|
||||
logit float64
|
||||
}
|
||||
|
||||
type TopK int
|
||||
|
||||
// TODO(parthsareen): avoid having to check all logits after this transform
|
||||
func (k TopK) Apply(logits []float64) []float64 {
|
||||
if int(k) >= len(logits) {
|
||||
return logits
|
||||
}
|
||||
q := pq.NewWith(func(a, b logitMap) int {
|
||||
return -cmp.Compare(a.logit, b.logit)
|
||||
})
|
||||
|
||||
for i, logit := range logits {
|
||||
q.Enqueue(logitMap{index: i, logit: logit})
|
||||
}
|
||||
|
||||
validLogits := make(map[int]float64)
|
||||
for range k {
|
||||
logitMap, _ := q.Dequeue()
|
||||
validLogits[logitMap.index] = logitMap.logit
|
||||
}
|
||||
|
||||
for i := range logits {
|
||||
if _, ok := validLogits[i]; !ok {
|
||||
logits[i] = math.Inf(-1)
|
||||
}
|
||||
}
|
||||
|
||||
return logits
|
||||
}
|
||||
|
||||
type TopP float64
|
||||
|
||||
func (p TopP) Apply(logits []float64) []float64 {
|
||||
probs := softmax(logits)
|
||||
indices := make([]int, len(probs))
|
||||
for i := range indices {
|
||||
indices[i] = i
|
||||
}
|
||||
|
||||
// sort in descending order
|
||||
slices.SortFunc(indices, func(i, j int) int {
|
||||
return cmp.Compare(probs[j], probs[i])
|
||||
})
|
||||
|
||||
var sum float64
|
||||
for i, idx := range indices {
|
||||
sum += probs[idx]
|
||||
if sum > float64(p) {
|
||||
for _, idx := range indices[i+1:] {
|
||||
logits[idx] = math.Inf(-1)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return logits
|
||||
}
|
||||
|
||||
type MinP float64
|
||||
|
||||
func (p MinP) Apply(logits []float64) []float64 {
|
||||
probs := softmax(logits)
|
||||
threshold := slices.Max(probs) * float64(p)
|
||||
|
||||
for i, prob := range probs {
|
||||
if prob < threshold {
|
||||
logits[i] = math.Inf(-1)
|
||||
}
|
||||
}
|
||||
|
||||
return logits
|
||||
}
|
80
sample/transforms_test.go
Normal file
80
sample/transforms_test.go
Normal file
@ -0,0 +1,80 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestTemperature(t *testing.T) {
|
||||
got := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0})
|
||||
want := []float64{-4, -10, 0, -14, -6, -12, -8}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftmax(t *testing.T) {
|
||||
got := softmax([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||
|
||||
want := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("probs mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTopK(t *testing.T) {
|
||||
got := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
got = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||
|
||||
want = []float64{-3, -2, -1, 0, 1, 2, 4}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTopP(t *testing.T) {
|
||||
got := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinP(t *testing.T) {
|
||||
got := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3})
|
||||
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3}
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTransform(b *testing.B) {
|
||||
transforms := map[string]Transform{
|
||||
"Temperature": Temperature(0.5),
|
||||
"TopK": TopK(10),
|
||||
"TopP": TopP(0.9),
|
||||
"MinP": MinP(0.2),
|
||||
}
|
||||
|
||||
logits := make([]float64, 1<<16)
|
||||
for i := range logits {
|
||||
logits[i] = rand.Float64()
|
||||
}
|
||||
|
||||
for name, transform := range transforms {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
transform.Apply(logits)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user