mirror of
https://github.com/ollama/ollama.git
synced 2025-03-26 09:42:10 +01:00
239 lines
4.5 KiB
Go
239 lines
4.5 KiB
Go
package sample
|
|
|
|
import (
|
|
"math"
|
|
"math/rand/v2"
|
|
"testing"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
)
|
|
|
|
func TestWeighted(t *testing.T) {
|
|
got, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))})
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
want := int32(1)
|
|
if want != got {
|
|
t.Errorf("index mismatch: want %d, got %d", want, got)
|
|
}
|
|
|
|
got, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))})
|
|
if err == nil {
|
|
t.Error("expected error for no valid tokens, got index", got)
|
|
}
|
|
|
|
seed := uint64(42)
|
|
got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4})
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
// With seed 42, we expect a consistent sample
|
|
want = int32(3) // This will be deterministic due to the seed
|
|
if want != got {
|
|
t.Errorf("index mismatch: want %d, got %d", want, got)
|
|
}
|
|
}
|
|
|
|
type testTransform struct {
|
|
id int
|
|
callOrder *[]int
|
|
}
|
|
|
|
func (ts *testTransform) Apply(logits []float64) []float64 {
|
|
if ts.callOrder != nil {
|
|
*ts.callOrder = append(*ts.callOrder, ts.id)
|
|
}
|
|
return logits
|
|
}
|
|
|
|
func TestSample(t *testing.T) {
|
|
input := []float32{1, 2, 3, 4}
|
|
|
|
var callOrder []int
|
|
mock1 := &testTransform{
|
|
id: 1,
|
|
callOrder: &callOrder,
|
|
}
|
|
mock2 := &testTransform{
|
|
id: 2,
|
|
callOrder: &callOrder,
|
|
}
|
|
mock3 := &testTransform{
|
|
id: 3,
|
|
callOrder: &callOrder,
|
|
}
|
|
|
|
got, err := Greedy(mock1, mock2, mock3).Sample(input)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
|
|
want := int32(3) // Greedy sampler should pick highest logit
|
|
if want != got {
|
|
t.Errorf("index mismatch: want %d, got %d", want, got)
|
|
}
|
|
wantOrder := []int{1, 2, 3}
|
|
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
|
|
t.Errorf("call order mismatch (-want +got):\n%s", diff)
|
|
}
|
|
|
|
callOrder = nil
|
|
|
|
_, err = Weighted(nil, mock1, mock2, mock3).Sample(input)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
wantOrder = []int{1, 2, 3}
|
|
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
|
|
t.Errorf("call order mismatch (-want +got):\n%s", diff)
|
|
}
|
|
}
|
|
|
|
func TestNewSampler(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
temperature float32
|
|
topK int
|
|
topP float32
|
|
minP float32
|
|
seed int
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "no transforms",
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "temperature",
|
|
temperature: 0.5,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "invalid temperature negative",
|
|
temperature: -1,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "invalid temperature too high",
|
|
temperature: 2.1,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "top k",
|
|
topK: 10,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "invalid top k negative",
|
|
topK: -1,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "top p",
|
|
topP: 0.9,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "invalid top p negative",
|
|
topP: -0.1,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "invalid top p one",
|
|
topP: 1.0,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "min p",
|
|
minP: 0.2,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "invalid min p negative",
|
|
minP: -0.1,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "invalid min p one",
|
|
minP: 1.0,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "seed",
|
|
seed: 42,
|
|
wantErr: true, // seed alone is not valid without other transforms
|
|
},
|
|
{
|
|
name: "default values",
|
|
temperature: 0.8,
|
|
topK: 40,
|
|
topP: 0.9,
|
|
minP: 0.0,
|
|
seed: 0,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "all zeroes",
|
|
temperature: 0.0,
|
|
topK: 0,
|
|
topP: 0.0,
|
|
minP: 0.0,
|
|
seed: 0,
|
|
wantErr: true, // all zeroes means no transforms
|
|
},
|
|
{
|
|
name: "all transforms",
|
|
temperature: 0.8,
|
|
topK: 50,
|
|
topP: 0.95,
|
|
minP: 0.1,
|
|
seed: 42,
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
_, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func BenchmarkSample(b *testing.B) {
|
|
transforms := []Transform{
|
|
Temperature(0.5),
|
|
TopK(10),
|
|
TopP(0.9),
|
|
MinP(0.2),
|
|
}
|
|
|
|
samplers := map[string]Sampler{
|
|
"Greedy": Greedy(transforms...),
|
|
"Weighted": Weighted(nil, transforms...),
|
|
}
|
|
|
|
logits := make([]float32, 1<<16)
|
|
for i := range logits {
|
|
logits[i] = rand.Float32()
|
|
}
|
|
|
|
for name, s := range samplers {
|
|
b.Run(name, func(b *testing.B) {
|
|
b.ResetTimer()
|
|
for range b.N {
|
|
if _, err := s.Sample(logits); err != nil {
|
|
b.Error(err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|