diff --git a/sample/greedy.go b/sample/greedy.go index 206f5544d..4d110f021 100644 --- a/sample/greedy.go +++ b/sample/greedy.go @@ -8,6 +8,6 @@ func Greedy() Sampler { return greedy{} } -func (s greedy) Sample(t []float64) ([]float64, error) { - return []float64{float64(floats.MaxIdx(t))}, nil +func (s greedy) Sample(t []float64) (int, error) { + return floats.MaxIdx(t), nil } diff --git a/sample/sample.go b/sample/sample.go index 714d994de..d873e7cee 100644 --- a/sample/sample.go +++ b/sample/sample.go @@ -10,13 +10,30 @@ import ( "gonum.org/v1/gonum/stat/sampleuv" ) +type Transform interface { + Apply([]float64) ([]float64, error) +} + type Sampler interface { - Sample([]float64) ([]float64, error) + Sample([]float64) (int, error) +} + +type SamplerConfig struct { + transforms []Transform + sampler Sampler +} + +// NewSampler creates a sampler with the given transforms and sampling method +func NewSampler(transforms []Transform, sampler Sampler) *SamplerConfig { + return &SamplerConfig{ + transforms: transforms, + sampler: sampler, + } } type Temperature float64 -func (t Temperature) Sample(logits []float64) ([]float64, error) { +func (t Temperature) Apply(logits []float64) ([]float64, error) { if t < 0 || t > 2 { return nil, errors.New("temperature must be between 0 and 2") } @@ -34,15 +51,16 @@ func (t Temperature) Sample(logits []float64) ([]float64, error) { type softmax struct{} -func Softmax() Sampler { +func Softmax() Transform { return softmax{} } -func (softmax) Sample(logits []float64) ([]float64, error) { - return computeSoftmax(logits) +func (softmax) Apply(logits []float64) ([]float64, error) { + return computeSoftmax(logits), nil } -func computeSoftmax(logits []float64) ([]float64, error) { +// TODO: cache softmax values +func computeSoftmax(logits []float64) []float64 { copiedLogits := make([]float64, len(logits)) copy(copiedLogits, logits) for i := range copiedLogits { @@ -52,12 +70,12 @@ func computeSoftmax(logits []float64) ([]float64, error) { floatSum := floats.Sum(copiedLogits) floats.Scale(1.0/floatSum, copiedLogits) - return copiedLogits, nil + return copiedLogits } type TopK int -func (k TopK) Sample(logits []float64) ([]float64, error) { +func (k TopK) Apply(logits []float64) ([]float64, error) { if k <= 0 { return nil, errors.New("k must be positive") } @@ -76,23 +94,20 @@ func (k TopK) Sample(logits []float64) ([]float64, error) { }) for _, idx := range indices[k:] { - logits[idx] = math.NaN() + logits[idx] = math.Inf(-1) } return logits, nil } -type TopP float32 +type TopP float64 -func (p TopP) Sample(logits []float64) ([]float64, error) { +func (p TopP) Apply(logits []float64) ([]float64, error) { if p <= 0 || p >= 1 { return nil, errors.New("p must be between 0 and 1") } - probs, err := computeSoftmax(logits) - if err != nil { - return nil, err - } + probs := computeSoftmax(logits) indices := make([]int, len(probs)) for i := range indices { @@ -104,12 +119,12 @@ func (p TopP) Sample(logits []float64) ([]float64, error) { return cmp.Compare(probs[j], probs[i]) }) - cumSum := 0.0 + var cumSum float64 for i, idx := range indices { cumSum += probs[idx] if cumSum > float64(p) { for _, idx := range indices[i+1:] { - logits[idx] = math.NaN() + logits[idx] = math.Inf(-1) } break } @@ -117,17 +132,14 @@ func (p TopP) Sample(logits []float64) ([]float64, error) { return logits, nil } -type MinP float32 +type MinP float64 -func (p MinP) Sample(logits []float64) ([]float64, error) { +func (p MinP) Apply(logits []float64) ([]float64, error) { if p <= 0 || p >= 1 { return nil, errors.New("p must be between 0 and 1") } - probs, err := computeSoftmax(logits) - if err != nil { - return nil, err - } + probs := computeSoftmax(logits) copiedProbs := make([]float64, len(probs)) copy(copiedProbs, probs) @@ -138,7 +150,7 @@ func (p MinP) Sample(logits []float64) ([]float64, error) { for i := range probs { if probs[i] < probThreshold { - logits[i] = math.NaN() + logits[i] = math.Inf(-1) } } @@ -151,48 +163,51 @@ func Weighed() Sampler { return weighed{} } -func (s weighed) Sample(logits []float64) ([]float64, error) { +// should return single value +func (s weighed) Sample(logits []float64) (int, error) { logitsCopy := make([]float64, 0, len(logits)) indices := make([]int, 0, len(logits)) // the uv sampler does not support NaN values for i, logit := range logits { - if !math.IsNaN(logit) { + if !math.IsInf(logit, -1) { logitsCopy = append(logitsCopy, logit) indices = append(indices, i) } } if len(logitsCopy) == 0 { - return nil, errors.New("no valid tokens found") + return -1, errors.New("no valid tokens found") } - softmax, err := computeSoftmax(logitsCopy) - if err != nil { - return nil, err - } + softmax := computeSoftmax(logitsCopy) w := sampleuv.NewWeighted(softmax, nil) - if v, ok := w.Take(); ok { + if idx, ok := w.Take(); ok { // returns the token ID - return []float64{float64(indices[v])}, nil + return indices[idx], nil } - return nil, errors.New("weighed sampler failed") + return -1, errors.New("weighed sampler failed") } -func Sample(logits []float64, samplers ...Sampler) ([]float64, error) { +// Sample applies transforms and samples a token ID +func (s *SamplerConfig) Sample(input []float32) (int, error) { + logits := make([]float64, len(input)) + for i, v := range input { + logits[i] = float64(v) + } + var err error - for _, sampler := range samplers { - if sampler == Temperature(0) { + for _, t := range s.transforms { + if t == Temperature(0) { // early return with greedy if temperature is 0 - logits, err = Greedy().Sample(logits) - if err != nil { - return nil, err - } - return logits, nil + s.sampler = Greedy() + break } - logits, err = sampler.Sample(logits) + + logits, err = t.Apply(logits) if err != nil { - return nil, err + return -1, err } } - return logits, nil + + return s.sampler.Sample(logits) } diff --git a/sample/sample_test.go b/sample/sample_test.go index 4039c29cd..78c7209e7 100644 --- a/sample/sample_test.go +++ b/sample/sample_test.go @@ -10,7 +10,7 @@ import ( ) func TestTemperature(t *testing.T) { - logits, err := Temperature(0.5).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}) + logits, err := Temperature(0.5).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) if err != nil { t.Fatal(err) } @@ -19,16 +19,16 @@ func TestTemperature(t *testing.T) { t.Fatalf("got: %v, want: %v", logits, want) } - if _, err := Temperature(-1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil { + if _, err := Temperature(-1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil { t.Fatalf("expected error for temperature=-1, got %v", logits) } - if _, err := Temperature(2.1).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil { + if _, err := Temperature(2.1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil { t.Fatalf("expected error for temperature=2.1, got %v", logits) } } func TestSoftmax(t *testing.T) { - probs, err := Softmax().Sample([]float64{-3, -2, -1, 0, 1, 2, 4}) + probs, err := Softmax().Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) if err != nil { t.Fatal(err) } @@ -40,95 +40,101 @@ func TestSoftmax(t *testing.T) { } func TestTopK(t *testing.T) { - logits, err := TopK(3).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}) + logits, err := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) if err != nil { t.Fatal(err) } - expectedlogits := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), 1, 2, 4} + expectedlogits := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4} if !floats.Same(logits, expectedlogits) { t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits) } - logits, err = TopK(0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}) + logits, err = TopK(0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) if err == nil { t.Fatalf("expected error for k=0, got %v", logits) } + + logits, err = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) + if err != nil { + t.Fatal(err) + } + expectedlogits = []float64{-3, -2, -1, 0, 1, 2, 4} + if !floats.Same(logits, expectedlogits) { + t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits) + } } func TestTopP(t *testing.T) { - logits, err := TopP(0.9).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}) + logits, err := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) if err != nil { t.Fatal(err) } - want := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 2, 4} + want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4} if !floats.Same(logits, want) { t.Fatalf("got: %v, want: %v", logits, want) } - logits, err = TopP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}) + logits, err = TopP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) if err == nil { t.Fatalf("expected error for p=1.0, got %v", logits) } - logits, err = TopP(0.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 4}) + logits, err = TopP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) if err == nil { t.Fatalf("expected error for p=0.0, got %v", logits) } } func TestMinP(t *testing.T) { - logits, err := MinP(0.2).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4}) + logits, err := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4}) if err != nil { t.Fatal(err) } - want := []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), 3, 4} + want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 3, 4} if !floats.Same(logits, want) { t.Fatalf("got: %v, want: %v", logits, want) } - logits, err = MinP(1.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4}) + logits, err = MinP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4}) if err == nil { t.Fatalf("expected error for p=1.0, got %v", logits) } - logits, err = MinP(0.0).Sample([]float64{-3, -2, -1, 0, 1, 2, 3, 4}) + logits, err = MinP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4}) if err == nil { t.Fatalf("expected error for p=0.0, got %v", logits) } } func TestWeighed(t *testing.T) { - logits, err := Weighed().Sample([]float64{math.NaN(), 2, math.NaN(), math.NaN()}) + idx, err := Weighed().Sample([]float64{math.Inf(-1), 2, math.Inf(-1), math.Inf(-1)}) if err != nil { t.Fatal(err) } - want := []float64{1} - if !floats.Equal(logits, want) { - t.Fatalf("got: %v, want: %v", logits, want) + want := 1 + if idx != want { + t.Fatalf("got: %v, want: %v", idx, want) } - logits, err = Weighed().Sample([]float64{math.NaN(), math.NaN(), math.NaN()}) + idx, err = Weighed().Sample([]float64{math.Inf(-1), math.Inf(-1), math.Inf(-1)}) if err == nil { - t.Fatalf("expected error for no valid tokens, got %v", logits) + t.Fatalf("expected error for no valid tokens, got %v", idx) } } func TestSample(t *testing.T) { - input := []float64{1, 2, 3, 4} - want := []float64{1, 2, 3, 4} + input := []float32{1, 2, 3, 4} var callOrder []int - mock1 := &testSampler{ - id: 1, - callOrder: &callOrder, - returnVals: want, + mock1 := &testTransform{ + id: 1, + callOrder: &callOrder, } - mock2 := &testSampler{ - id: 2, - callOrder: &callOrder, - returnVals: want, + mock2 := &testTransform{ + id: 2, + callOrder: &callOrder, } - mock3 := &testSampler{ - id: 3, - callOrder: &callOrder, - returnVals: want, + mock3 := &testTransform{ + id: 3, + callOrder: &callOrder, } + sampler := NewSampler([]Transform{mock1, mock2, mock3}, Greedy()) - got, err := Sample(input, mock1, mock2, mock3) + got, err := sampler.Sample(input) if err != nil { t.Fatal(err) } @@ -137,43 +143,45 @@ func TestSample(t *testing.T) { t.Errorf("got %v, want %v", callOrder, []int{1, 2, 3}) } - if !floats.Equal(got, want) { + want := 3 // Greedy sampler should pick highest logit + if got != want { t.Errorf("got %v, want %v", got, want) } - errMock := &testSampler{ + errMock := &testTransform{ returnErr: fmt.Errorf("mock error"), } - _, err = Sample(input, mock1, errMock, mock2) + sampler = NewSampler([]Transform{mock1, errMock, mock2}, Greedy()) + _, err = sampler.Sample(input) if err == nil { t.Error("Expected error from sampler") } } -type testSampler struct { - id int - callOrder *[]int - returnVals []float64 - returnErr error +type testTransform struct { + id int + callOrder *[]int + returnErr error } -func (ts *testSampler) Sample(logits []float64) ([]float64, error) { +func (ts *testTransform) Apply(logits []float64) ([]float64, error) { if ts.callOrder != nil { *ts.callOrder = append(*ts.callOrder, ts.id) } if ts.returnErr != nil { return nil, ts.returnErr } - return ts.returnVals, nil + return logits, nil } func TestSampleTemperatureZero(t *testing.T) { - logits, err := Sample([]float64{1, 2, 3, 4}, Temperature(0)) + sampler := NewSampler([]Transform{Temperature(0)}, Greedy()) + got, err := sampler.Sample([]float32{1, 2, 3, 4}) if err != nil { t.Fatal(err) } - want := []float64{3} - if !floats.Equal(logits, want) { - t.Fatalf("got: %v, want: %v", logits, want) + want := 3 // Greedy sampler should pick highest logit index + if got != want { + t.Fatalf("got: %v, want: %v", got, want) } }