package sample

import (
	"math/rand/v2"
	"testing"
)

// TestWeighted checks that Sample returns the index of the dominant logit.
//
// Every case builds its sampler with temperature 0, which is the greedy
// configuration, so the expected result is the arg-max index of the logits.
func TestWeighted(t *testing.T) {
	tests := []struct {
		name   string
		logits []float32
		want   int32
	}{
		{name: "single dominant logit", logits: []float32{-10, 3, -10, -10}, want: 1},
		{name: "increasing logits", logits: []float32{-100, -10, 0, 10}, want: 3},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// A fresh sampler per case keeps cases independent of any state
			// the sampler may carry between calls.
			sampler := NewSampler(0, 0, 0, 0, 0, nil)

			got, err := sampler.Sample(tt.logits)
			if err != nil {
				t.Fatal(err)
			}
			if got != tt.want {
				t.Errorf("index mismatch: want %d, got %d", tt.want, got)
			}
		})
	}
}

// BenchmarkSample measures Sample over 1<<16 random logits for a greedy and a
// weighted sampler configuration, each reported as its own sub-benchmark.
func BenchmarkSample(b *testing.B) {
	// An ordered slice (rather than a map) keeps the sub-benchmark order
	// stable from run to run.
	samplers := []struct {
		name    string
		sampler Sampler
	}{
		{"Greedy", NewSampler(0, 0, 0, 0, 0, nil)}, // temperature 0 selects greedy sampling
		{"Weighted", NewSampler(0.5, 10, 0.9, 0.2, -1, nil)},
	}

	// Logits drawn uniformly from [0, 1).
	logits := make([]float32, 1<<16)
	for i := range logits {
		logits[i] = rand.Float32()
	}

	for _, bc := range samplers {
		b.Run(bc.name, func(b *testing.B) {
			// b.Loop resets the timer on its first call, so the setup above
			// is already excluded without an explicit b.ResetTimer.
			for b.Loop() {
				if _, err := bc.sampler.Sample(logits); err != nil {
					b.Fatalf("error sampling: %v", err)
				}
			}
		})
	}
}