mirror of
https://github.com/ollama/ollama.git
synced 2025-06-11 04:30:51 +02:00
sample: improve ollama engine sampler performance (#9374)
This change bring in various interface cleanups along with greatly improving the performance of the sampler. Tested with llama3.2 on local machine. Improves performance from ~ 70 tokens/s -> 135 tokens/s with topK(40) enabled. Without topK performance is ~ 110 tokens/s
This commit is contained in:
parent
1f6986e919
commit
0682dae027
2
go.mod
2
go.mod
@ -25,7 +25,6 @@ require (
|
|||||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||||
golang.org/x/image v0.22.0
|
golang.org/x/image v0.22.0
|
||||||
golang.org/x/tools v0.30.0
|
golang.org/x/tools v0.30.0
|
||||||
gonum.org/v1/gonum v0.15.0
|
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
@ -45,6 +44,7 @@ require (
|
|||||||
github.com/xtgo/set v1.0.0 // indirect
|
github.com/xtgo/set v1.0.0 // indirect
|
||||||
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
|
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
|
||||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||||
|
gonum.org/v1/gonum v0.15.0 // indirect
|
||||||
gorgonia.org/vecf32 v0.9.0 // indirect
|
gorgonia.org/vecf32 v0.9.0 // indirect
|
||||||
gorgonia.org/vecf64 v0.9.0 // indirect
|
gorgonia.org/vecf64 v0.9.0 // indirect
|
||||||
)
|
)
|
||||||
|
@ -589,11 +589,19 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sampler := sample.NewSampler(
|
||||||
|
req.Temperature,
|
||||||
|
req.TopK,
|
||||||
|
req.TopP,
|
||||||
|
req.MinP,
|
||||||
|
req.Seed,
|
||||||
|
)
|
||||||
|
|
||||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||||
numPredict: req.NumPredict,
|
numPredict: req.NumPredict,
|
||||||
stop: req.Stop,
|
stop: req.Stop,
|
||||||
numKeep: int32(req.NumKeep),
|
numKeep: int32(req.NumKeep),
|
||||||
sampler: sample.Greedy(), // TODO: add support for different samplers when performance is optimized
|
sampler: sampler,
|
||||||
embedding: false,
|
embedding: false,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -2,76 +2,103 @@ package sample
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"math"
|
"math/rand/v2"
|
||||||
|
"slices"
|
||||||
"golang.org/x/exp/rand"
|
|
||||||
"gonum.org/v1/gonum/stat/sampleuv"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Sampler is not thread-safe. Each goroutine should have its own instance
|
||||||
type Sampler interface {
|
type Sampler interface {
|
||||||
Sample([]float32) (int32, error)
|
Sample([]float32) (int32, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// logit represents information about a single token during sampling
|
||||||
|
type logit struct {
|
||||||
|
id int32 // The token's unique identifier
|
||||||
|
value float32 // The raw logit or probability from the model
|
||||||
|
}
|
||||||
|
|
||||||
type weighted struct {
|
type weighted struct {
|
||||||
src rand.Source
|
rng *rand.Rand
|
||||||
transforms []Transform
|
tokens []logit
|
||||||
|
topK int
|
||||||
|
topP float32
|
||||||
|
minP float32
|
||||||
|
temperature float32
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(parthsareen): remove uv sample dependency https://github.com/ollama/ollama/issues/9279
|
func (s *weighted) Sample(logits []float32) (int32, error) {
|
||||||
func Weighted(seed *uint64, transforms ...Transform) Sampler {
|
if len(s.tokens) < len(logits) {
|
||||||
var src rand.Source
|
s.tokens = make([]logit, len(logits))
|
||||||
if seed != nil {
|
|
||||||
src = rand.NewSource(*seed)
|
|
||||||
}
|
}
|
||||||
return weighted{src: src, transforms: transforms}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s weighted) Sample(logits []float32) (int32, error) {
|
tokens := s.tokens[:len(logits)]
|
||||||
logits64 := make([]float64, len(logits))
|
|
||||||
for i, v := range logits {
|
for i, v := range logits {
|
||||||
logits64[i] = float64(v)
|
tokens[i].id = int32(i)
|
||||||
|
tokens[i].value = v
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, t := range s.transforms {
|
// Tokens are sorted by logits in TopK or SortTokens
|
||||||
logits64 = t.Apply(logits64)
|
if s.topK > 0 {
|
||||||
|
tokens = topK(tokens, s.topK)
|
||||||
|
} else {
|
||||||
|
sortLogits(tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
logitsCopy := make([]float64, 0, len(logits))
|
tokens = temperature(tokens, s.temperature)
|
||||||
indices := make([]int, 0, len(logits))
|
tokens = softmax(tokens)
|
||||||
for i, logit := range logits64 {
|
|
||||||
if !math.IsInf(logit, -1) {
|
tokens = topP(tokens, s.topP)
|
||||||
logitsCopy = append(logitsCopy, logit)
|
tokens = minP(tokens, s.minP)
|
||||||
indices = append(indices, i)
|
|
||||||
|
if len(tokens) == 0 {
|
||||||
|
return -1, errors.New("no valid logits found for weighted sampling")
|
||||||
|
}
|
||||||
|
|
||||||
|
var r float32
|
||||||
|
if s.rng != nil {
|
||||||
|
r = s.rng.Float32()
|
||||||
|
} else {
|
||||||
|
r = rand.Float32()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate cumulative sum of probabilities
|
||||||
|
var sum float32
|
||||||
|
for i := range tokens {
|
||||||
|
sum += tokens[i].value
|
||||||
|
tokens[i].value = sum
|
||||||
|
}
|
||||||
|
r *= tokens[len(tokens)-1].value
|
||||||
|
|
||||||
|
idx, _ := slices.BinarySearchFunc(tokens, r, func(token logit, target float32) int {
|
||||||
|
// Compare cumulative probabilities
|
||||||
|
if token.value < target {
|
||||||
|
return -1
|
||||||
}
|
}
|
||||||
|
// First token that exceeds target
|
||||||
|
return 1
|
||||||
|
})
|
||||||
|
|
||||||
|
if idx >= len(tokens) {
|
||||||
|
idx = len(tokens) - 1
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(logitsCopy) == 0 {
|
return tokens[idx].id, nil
|
||||||
return -1, errors.New("no valid logits found for weighed sampling")
|
|
||||||
}
|
|
||||||
|
|
||||||
probs := softmax(logitsCopy)
|
|
||||||
w := sampleuv.NewWeighted(probs, s.src)
|
|
||||||
if idx, ok := w.Take(); ok {
|
|
||||||
return int32(indices[idx]), nil
|
|
||||||
}
|
|
||||||
return -1, errors.New("weighted sampler failed, no valid token found")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type greedy struct{}
|
type greedy struct{}
|
||||||
|
|
||||||
func Greedy() Sampler {
|
// Greedy sample returns the index of the maximum value in logits.
|
||||||
return greedy{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sample returns the index of the maximum value in logits.
|
|
||||||
func (s greedy) Sample(logits []float32) (int32, error) {
|
func (s greedy) Sample(logits []float32) (int32, error) {
|
||||||
if len(logits) == 0 {
|
if len(logits) == 0 {
|
||||||
return -1, errors.New("no logits provided for greedy sampling")
|
return -1, errors.New("no logits provided for greedy sampling")
|
||||||
}
|
}
|
||||||
|
|
||||||
maxIdx := 0
|
maxIdx := 0
|
||||||
for i := range logits {
|
maxVal := logits[0]
|
||||||
if logits[i] > logits[maxIdx] {
|
for i := 1; i < len(logits); i++ {
|
||||||
|
if logits[i] > maxVal {
|
||||||
|
maxVal = logits[i]
|
||||||
maxIdx = i
|
maxIdx = i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -80,41 +107,40 @@ func (s greedy) Sample(logits []float32) (int32, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
||||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) (Sampler, error) {
|
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) Sampler {
|
||||||
if temperature == 0 {
|
if temperature == 0 {
|
||||||
return Greedy(), nil
|
return &greedy{}
|
||||||
}
|
}
|
||||||
|
|
||||||
if temperature < 0 || temperature > 2 {
|
var rng *rand.Rand
|
||||||
return nil, errors.New("temperature must be between 0 and 2")
|
if seed != -1 {
|
||||||
|
// PCG requires two parameters: sequence and stream
|
||||||
|
// Use original seed for sequence
|
||||||
|
sequence := uint64(seed)
|
||||||
|
// Use golden ratio hash to generate statistically independent seeds
|
||||||
|
rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
|
||||||
|
}
|
||||||
|
temperature = max(temperature, 1)
|
||||||
|
|
||||||
|
if topP < 0.0 {
|
||||||
|
topP = 0.0
|
||||||
|
}
|
||||||
|
if topP >= 1.0 {
|
||||||
|
topP = 1.0
|
||||||
}
|
}
|
||||||
|
|
||||||
transforms := []Transform{Temperature(temperature)}
|
if minP < 0.0 {
|
||||||
|
minP = 0.0
|
||||||
if topK != 0 {
|
}
|
||||||
if topK <= 0 {
|
if minP >= 1.0 {
|
||||||
return nil, errors.New("topK must be greater than 0")
|
minP = 1.0
|
||||||
}
|
|
||||||
transforms = append(transforms, TopK(topK))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if topP != 0 {
|
return &weighted{
|
||||||
if topP < 0 || topP >= 1 {
|
rng: rng,
|
||||||
return nil, errors.New("topP must be between 0 and 1")
|
topK: topK,
|
||||||
}
|
topP: topP,
|
||||||
transforms = append(transforms, TopP(topP))
|
minP: minP,
|
||||||
|
temperature: temperature,
|
||||||
}
|
}
|
||||||
|
|
||||||
if minP != 0 {
|
|
||||||
if minP < 0 || minP >= 1 {
|
|
||||||
return nil, errors.New("minP must be between 0 and 1")
|
|
||||||
}
|
|
||||||
transforms = append(transforms, MinP(minP))
|
|
||||||
}
|
|
||||||
|
|
||||||
if seed >= 0 {
|
|
||||||
seed64 := uint64(seed)
|
|
||||||
return Weighted(&seed64, transforms...), nil
|
|
||||||
}
|
|
||||||
return Weighted(nil, transforms...), nil
|
|
||||||
}
|
}
|
||||||
|
104
sample/samplers_benchmark_test.go
Normal file
104
sample/samplers_benchmark_test.go
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkWeightedSampler(b *testing.B) {
|
||||||
|
sizes := []int{10, 100, 1000, 10000}
|
||||||
|
|
||||||
|
for _, size := range sizes {
|
||||||
|
b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
|
||||||
|
logits := make([]float32, size)
|
||||||
|
for i := range logits {
|
||||||
|
logits[i] = float32(rand.Float64()*10 - 5)
|
||||||
|
}
|
||||||
|
|
||||||
|
sampler := NewSampler(0.8, 0, 0, 0, 42)
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
_, err := sampler.Sample(logits)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Sampling failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
configs := []struct {
|
||||||
|
name string
|
||||||
|
temperature float32
|
||||||
|
topK int
|
||||||
|
topP float32
|
||||||
|
minP float32
|
||||||
|
seed int
|
||||||
|
}{
|
||||||
|
{"Greedy", 0, -1, 0, 0, -1},
|
||||||
|
{"Temperature", 0.8, -1, 0, 0, -1},
|
||||||
|
{"TopK", 0.8, 50, 0, 0, -1},
|
||||||
|
{"TopP", 0.8, -1, 0.9, 0, -1},
|
||||||
|
{"MinP", 0.8, -1, 0, 0.05, -1},
|
||||||
|
{"WithSeed", 0.8, 50, 0, 0, 42},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fixed size for common vocab size
|
||||||
|
size := 128000
|
||||||
|
logits := make([]float32, size)
|
||||||
|
for i := range logits {
|
||||||
|
logits[i] = float32(rand.Float64()*10 - 5)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range configs {
|
||||||
|
b.Run("Config"+tc.name, func(b *testing.B) {
|
||||||
|
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed)
|
||||||
|
sampler.Sample(logits)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for b.Loop() {
|
||||||
|
_, err := sampler.Sample(logits)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Sampling failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with combined transforms separately - topK influences performance greatly
|
||||||
|
b.Run("TransformCombined", func(b *testing.B) {
|
||||||
|
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42)
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for b.Loop() {
|
||||||
|
_, err := sampler.Sample(logits)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Sampling failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkGreedySampler(b *testing.B) {
|
||||||
|
sizes := []int{10, 100, 1000, 10000, 100000}
|
||||||
|
|
||||||
|
for _, size := range sizes {
|
||||||
|
b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
|
||||||
|
logits := make([]float32, size)
|
||||||
|
for i := range logits {
|
||||||
|
logits[i] = float32(rand.Float64()*10 - 5)
|
||||||
|
}
|
||||||
|
|
||||||
|
sampler := NewSampler(0, -1, 0, 0, -1)
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for b.Loop() {
|
||||||
|
_, err := sampler.Sample(logits)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Sampling failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -1,15 +1,14 @@
|
|||||||
package sample
|
package sample
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math"
|
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestWeighted(t *testing.T) {
|
func TestWeighted(t *testing.T) {
|
||||||
got, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))})
|
logits := []float32{-10, 3, -10, -10}
|
||||||
|
sampler := NewSampler(0, 0, 0, 0, 0)
|
||||||
|
got, err := sampler.Sample(logits)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
@ -19,64 +18,19 @@ func TestWeighted(t *testing.T) {
|
|||||||
t.Errorf("index mismatch: want %d, got %d", want, got)
|
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||||
}
|
}
|
||||||
|
|
||||||
got, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))})
|
logits = []float32{-100, -10, 0, 10}
|
||||||
if err == nil {
|
sampler = NewSampler(0, 0, 0, 0, 0)
|
||||||
t.Error("expected error for no valid tokens, got index", got)
|
got, err = sampler.Sample(logits)
|
||||||
}
|
|
||||||
|
|
||||||
seed := uint64(42)
|
|
||||||
got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// With seed 42, we expect a consistent sample
|
want = int32(3) // Should pick highest probability with this r value
|
||||||
want = int32(3) // This will be deterministic due to the seed
|
|
||||||
if want != got {
|
if want != got {
|
||||||
t.Errorf("index mismatch: want %d, got %d", want, got)
|
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type testTransform struct {
|
|
||||||
id int
|
|
||||||
callOrder *[]int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ts *testTransform) Apply(logits []float64) []float64 {
|
|
||||||
if ts.callOrder != nil {
|
|
||||||
*ts.callOrder = append(*ts.callOrder, ts.id)
|
|
||||||
}
|
|
||||||
return logits
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSample(t *testing.T) {
|
|
||||||
input := []float32{1, 2, 3, 4}
|
|
||||||
|
|
||||||
var callOrder []int
|
|
||||||
mock1 := &testTransform{
|
|
||||||
id: 1,
|
|
||||||
callOrder: &callOrder,
|
|
||||||
}
|
|
||||||
mock2 := &testTransform{
|
|
||||||
id: 2,
|
|
||||||
callOrder: &callOrder,
|
|
||||||
}
|
|
||||||
mock3 := &testTransform{
|
|
||||||
id: 3,
|
|
||||||
callOrder: &callOrder,
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := Weighted(nil, mock1, mock2, mock3).Sample(input)
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
wantOrder := []int{1, 2, 3}
|
|
||||||
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
|
|
||||||
t.Errorf("call order mismatch (-want +got):\n%s", diff)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewSampler(t *testing.T) {
|
func TestNewSampler(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@ -85,75 +39,41 @@ func TestNewSampler(t *testing.T) {
|
|||||||
topP float32
|
topP float32
|
||||||
minP float32
|
minP float32
|
||||||
seed int
|
seed int
|
||||||
wantErr bool
|
wantGreedy bool // Instead of wantErr, check if we get greedy sampler
|
||||||
}{
|
}{
|
||||||
{
|
|
||||||
name: "no transforms",
|
|
||||||
// temperature is 0, so greedy should be used
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "temperature",
|
name: "temperature",
|
||||||
temperature: 0.5,
|
temperature: 0.5,
|
||||||
wantErr: false,
|
wantGreedy: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid temperature negative",
|
name: "zero temperature - greedy",
|
||||||
temperature: -1,
|
temperature: 0,
|
||||||
wantErr: true,
|
wantGreedy: true,
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid temperature too high",
|
|
||||||
temperature: 2.1,
|
|
||||||
wantErr: true,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "top k",
|
name: "top k",
|
||||||
|
temperature: 0.1,
|
||||||
topK: 10,
|
topK: 10,
|
||||||
temperature: 0.8,
|
wantGreedy: false,
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid top k negative",
|
|
||||||
topK: -1,
|
|
||||||
temperature: 0.8,
|
|
||||||
wantErr: true,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "top p",
|
name: "top p",
|
||||||
|
temperature: 0.1,
|
||||||
topP: 0.9,
|
topP: 0.9,
|
||||||
temperature: 0.8,
|
wantGreedy: false,
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid top p negative",
|
|
||||||
topP: -0.1,
|
|
||||||
temperature: 0.8,
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid top p one",
|
|
||||||
topP: 1.0,
|
|
||||||
temperature: 0.8,
|
|
||||||
wantErr: true,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "min p",
|
name: "min p",
|
||||||
|
temperature: 0.1,
|
||||||
minP: 0.2,
|
minP: 0.2,
|
||||||
temperature: 0.8,
|
wantGreedy: false,
|
||||||
wantErr: false,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid min p negative",
|
name: "seed - weighted",
|
||||||
minP: -0.1,
|
temperature: 0.1,
|
||||||
temperature: 0.8,
|
seed: 42,
|
||||||
wantErr: true,
|
wantGreedy: false,
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid min p one",
|
|
||||||
minP: 1.0,
|
|
||||||
temperature: 0.8,
|
|
||||||
wantErr: true,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "default values",
|
name: "default values",
|
||||||
@ -162,16 +82,16 @@ func TestNewSampler(t *testing.T) {
|
|||||||
topP: 0.9,
|
topP: 0.9,
|
||||||
minP: 0.0,
|
minP: 0.0,
|
||||||
seed: 0,
|
seed: 0,
|
||||||
wantErr: false,
|
wantGreedy: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "all zeroes",
|
name: "all zeroes - greedy",
|
||||||
temperature: 0.0,
|
temperature: 0.0,
|
||||||
topK: 0,
|
topK: 0,
|
||||||
topP: 0.0,
|
topP: 0.0,
|
||||||
minP: 0.0,
|
minP: 0.0,
|
||||||
seed: 0,
|
seed: 0,
|
||||||
wantErr: false, // all zeroes means no transforms
|
wantGreedy: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "all transforms",
|
name: "all transforms",
|
||||||
@ -180,33 +100,28 @@ func TestNewSampler(t *testing.T) {
|
|||||||
topP: 0.95,
|
topP: 0.95,
|
||||||
minP: 0.1,
|
minP: 0.1,
|
||||||
seed: 42,
|
seed: 42,
|
||||||
wantErr: false,
|
wantGreedy: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
_, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
|
sampler := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
|
||||||
if (err != nil) != tt.wantErr {
|
_, isGreedy := sampler.(*greedy)
|
||||||
t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr)
|
if isGreedy != tt.wantGreedy {
|
||||||
|
t.Errorf("NewSampler() got greedy = %v, want %v", isGreedy, tt.wantGreedy)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkSample(b *testing.B) {
|
func BenchmarkSample(b *testing.B) {
|
||||||
transforms := []Transform{
|
weighted := NewSampler(0.5, 10, 0.9, 0.2, -1)
|
||||||
Temperature(0.5),
|
|
||||||
TopK(10),
|
|
||||||
TopP(0.9),
|
|
||||||
MinP(0.2),
|
|
||||||
}
|
|
||||||
|
|
||||||
samplers := map[string]Sampler{
|
samplers := map[string]Sampler{
|
||||||
"Greedy": Greedy(),
|
"Greedy": NewSampler(0, 0, 0, 0, 0), // Use NewSampler with temp=0 for greedy
|
||||||
"Weighted": Weighted(nil, transforms...),
|
"Weighted": weighted,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Generate random logits for benchmarking
|
||||||
logits := make([]float32, 1<<16)
|
logits := make([]float32, 1<<16)
|
||||||
for i := range logits {
|
for i := range logits {
|
||||||
logits[i] = rand.Float32()
|
logits[i] = rand.Float32()
|
||||||
@ -215,7 +130,7 @@ func BenchmarkSample(b *testing.B) {
|
|||||||
for name, s := range samplers {
|
for name, s := range samplers {
|
||||||
b.Run(name, func(b *testing.B) {
|
b.Run(name, func(b *testing.B) {
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for range b.N {
|
for b.Loop() {
|
||||||
if _, err := s.Sample(logits); err != nil {
|
if _, err := s.Sample(logits); err != nil {
|
||||||
b.Error(err)
|
b.Error(err)
|
||||||
}
|
}
|
||||||
|
@ -1,120 +1,203 @@
|
|||||||
package sample
|
package sample
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"cmp"
|
|
||||||
"math"
|
"math"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
pq "github.com/emirpasic/gods/v2/queues/priorityqueue"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Transform interface {
|
func softmax(ts []logit) []logit {
|
||||||
Apply([]float64) []float64
|
var sum float32
|
||||||
}
|
for i, v := range ts {
|
||||||
|
ts[i].value = float32(math.Exp(float64(v.value)))
|
||||||
// TODO(parthsareen): potentially cache softmax values
|
sum += ts[i].value
|
||||||
func softmax(logits []float64) []float64 {
|
|
||||||
var sum float64
|
|
||||||
probs := make([]float64, len(logits))
|
|
||||||
for i, v := range logits {
|
|
||||||
probs[i] = math.Exp(v)
|
|
||||||
sum += probs[i]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range probs {
|
for i := range ts {
|
||||||
probs[i] /= sum
|
ts[i].value /= sum
|
||||||
}
|
}
|
||||||
|
|
||||||
return probs
|
return ts
|
||||||
}
|
}
|
||||||
|
|
||||||
type Temperature float64
|
func temperature(ti []logit, t float32) []logit {
|
||||||
|
if t == 1 {
|
||||||
|
return ti
|
||||||
|
}
|
||||||
|
|
||||||
func (t Temperature) Apply(logits []float64) []float64 {
|
temp := max(t, 1e-7)
|
||||||
temp := math.Max(float64(t), 1e-7)
|
maxLogit := float32(math.Inf(-1))
|
||||||
|
for _, token := range ti {
|
||||||
|
if token.value > maxLogit {
|
||||||
|
maxLogit = token.value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// subtracting max logit to avoid under/overflow
|
// subtracting max logit to avoid under/overflow
|
||||||
maxLogit := slices.Max(logits)
|
for i := range ti {
|
||||||
for i := range logits {
|
ti[i].value = (ti[i].value - maxLogit) / temp
|
||||||
logits[i] = (logits[i] - maxLogit) / temp
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return logits
|
return ti
|
||||||
}
|
}
|
||||||
|
|
||||||
type logitMap struct {
|
// siftDown maintains a min-heap property by recursively moving larger elements down the heap.
|
||||||
index int
|
//
|
||||||
logit float64
|
// The heap is represented as an array where for any node at index i:
|
||||||
}
|
// - Left child is at index 2i + 1
|
||||||
|
// - Right child is at index 2i + 2
|
||||||
type TopK int
|
// - Parent is at index (i-1)/2
|
||||||
|
//
|
||||||
// TODO(parthsareen): avoid having to check all logits after this transform
|
// The function compares a node with its children and:
|
||||||
func (k TopK) Apply(logits []float64) []float64 {
|
// 1. Finds the smallest value between the node and its children
|
||||||
if int(k) >= len(logits) {
|
// 2. If the node is not the smallest, swaps it with its smallest child
|
||||||
return logits
|
// 3. Continues this process down the affected path until the min-heap property is restored
|
||||||
}
|
func siftDown(data []logit, start, end int) {
|
||||||
q := pq.NewWith(func(a, b logitMap) int {
|
root := start
|
||||||
return -cmp.Compare(a.logit, b.logit)
|
for {
|
||||||
})
|
child := 2*root + 1
|
||||||
|
if child >= end {
|
||||||
for i, logit := range logits {
|
|
||||||
q.Enqueue(logitMap{index: i, logit: logit})
|
|
||||||
}
|
|
||||||
|
|
||||||
validLogits := make(map[int]float64)
|
|
||||||
for range k {
|
|
||||||
logitMap, _ := q.Dequeue()
|
|
||||||
validLogits[logitMap.index] = logitMap.logit
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range logits {
|
|
||||||
if _, ok := validLogits[i]; !ok {
|
|
||||||
logits[i] = math.Inf(-1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return logits
|
|
||||||
}
|
|
||||||
|
|
||||||
type TopP float64
|
|
||||||
|
|
||||||
func (p TopP) Apply(logits []float64) []float64 {
|
|
||||||
probs := softmax(logits)
|
|
||||||
indices := make([]int, len(probs))
|
|
||||||
for i := range indices {
|
|
||||||
indices[i] = i
|
|
||||||
}
|
|
||||||
|
|
||||||
// sort in descending order
|
|
||||||
slices.SortFunc(indices, func(i, j int) int {
|
|
||||||
return cmp.Compare(probs[j], probs[i])
|
|
||||||
})
|
|
||||||
|
|
||||||
var sum float64
|
|
||||||
for i, idx := range indices {
|
|
||||||
sum += probs[idx]
|
|
||||||
if sum > float64(p) {
|
|
||||||
for _, idx := range indices[i+1:] {
|
|
||||||
logits[idx] = math.Inf(-1)
|
|
||||||
}
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
// Find smaller child (we want min heap)
|
||||||
|
if child+1 < end && data[child+1].value < data[child].value {
|
||||||
|
child++
|
||||||
|
}
|
||||||
|
// Exit if root is already smaller than children
|
||||||
|
if data[root].value <= data[child].value {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// Swap with smaller child and continue
|
||||||
|
data[root], data[child] = data[child], data[root]
|
||||||
|
root = child
|
||||||
}
|
}
|
||||||
return logits
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type MinP float64
|
// topK limits the number of tokens considered to the k highest logits
|
||||||
|
func topK(ts []logit, k int) []logit {
|
||||||
|
if k >= len(ts) {
|
||||||
|
return ts
|
||||||
|
}
|
||||||
|
// Heapify + siftDown - O(nlog(k))
|
||||||
|
// Build min-heap of first k elements
|
||||||
|
heap := ts[:k]
|
||||||
|
for i := k/2 - 1; i >= 0; i-- {
|
||||||
|
siftDown(heap, i, k)
|
||||||
|
}
|
||||||
|
|
||||||
func (p MinP) Apply(logits []float64) []float64 {
|
// Process remaining elements - if larger than heap root, replace root
|
||||||
probs := softmax(logits)
|
for i := k; i < len(ts); i++ {
|
||||||
threshold := slices.Max(probs) * float64(p)
|
if ts[i].value > heap[0].value {
|
||||||
|
heap[0] = ts[i]
|
||||||
for i, prob := range probs {
|
siftDown(heap, 0, k)
|
||||||
if prob < threshold {
|
|
||||||
logits[i] = math.Inf(-1)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return logits
|
slices.Reverse(heap)
|
||||||
|
|
||||||
|
ts = heap
|
||||||
|
return ts
|
||||||
|
}
|
||||||
|
|
||||||
|
// topP limits tokens to those with cumulative probability p
|
||||||
|
func topP(ts []logit, p float32) []logit {
|
||||||
|
if p == 1.0 {
|
||||||
|
return ts
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find cutoff index where cumulative sum exceeds p
|
||||||
|
var sum float32
|
||||||
|
for i, t := range ts {
|
||||||
|
sum += t.value
|
||||||
|
if sum > float32(p) {
|
||||||
|
ts = ts[:i+1]
|
||||||
|
return ts
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ts
|
||||||
|
}
|
||||||
|
|
||||||
|
// minP limits tokens to those with cumulative probability p
|
||||||
|
func minP(ts []logit, p float32) []logit {
|
||||||
|
if p == 1.0 {
|
||||||
|
return ts
|
||||||
|
}
|
||||||
|
|
||||||
|
maxProb := float32(math.Inf(-1))
|
||||||
|
for _, token := range ts {
|
||||||
|
if token.value > maxProb {
|
||||||
|
maxProb = token.value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
threshold := maxProb * float32(p)
|
||||||
|
|
||||||
|
// Filter tokens in-place
|
||||||
|
validTokens := ts[:0]
|
||||||
|
for i, token := range ts {
|
||||||
|
if token.value >= threshold {
|
||||||
|
validTokens = append(validTokens, ts[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ts = validTokens
|
||||||
|
return ts
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584
|
||||||
|
// Conting sort implementation to sort tokens by logits
|
||||||
|
func sortLogits(tokens []logit) {
|
||||||
|
if len(tokens) <= 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find max/min in a single pass
|
||||||
|
minLogit, maxLogit := tokens[0].value, tokens[0].value
|
||||||
|
for _, t := range tokens[1:] {
|
||||||
|
if t.value < minLogit {
|
||||||
|
minLogit = t.value
|
||||||
|
} else if t.value > maxLogit {
|
||||||
|
maxLogit = t.value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate scaling to map to uint32 range
|
||||||
|
logitRange := maxLogit - minLogit
|
||||||
|
if logitRange < 1e-6 {
|
||||||
|
return // All values effectively equal
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count frequencies directly from tokens
|
||||||
|
const maxInt = (1 << 24) - 1 // Use 24 bits for good granularity
|
||||||
|
var counts [256]int // For first byte
|
||||||
|
|
||||||
|
// First pass: count frequencies
|
||||||
|
for _, t := range tokens {
|
||||||
|
// Map to [0, maxInt] range
|
||||||
|
score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt)
|
||||||
|
counts[score>>16]++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate offsets
|
||||||
|
var offset int
|
||||||
|
for i := range counts {
|
||||||
|
count := counts[i]
|
||||||
|
counts[i] = offset
|
||||||
|
offset += count
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second pass: place elements in correct position
|
||||||
|
output := make([]logit, len(tokens))
|
||||||
|
// Track current positions
|
||||||
|
countsCopy := counts
|
||||||
|
|
||||||
|
for i, t := range tokens {
|
||||||
|
score := min(uint32((t.value-minLogit)*float32(maxInt)/logitRange), maxInt)
|
||||||
|
|
||||||
|
pos := countsCopy[score>>16]
|
||||||
|
countsCopy[score>>16]++
|
||||||
|
output[len(tokens)-1-pos] = tokens[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(tokens, output)
|
||||||
}
|
}
|
||||||
|
@ -4,77 +4,182 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTemperature(t *testing.T) {
|
// Helper to convert float64 slice to logit slice
|
||||||
got := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0})
|
func toLogits(values []float64) []logit {
|
||||||
want := []float64{-4, -10, 0, -14, -6, -12, -8}
|
tokens := make([]logit, len(values))
|
||||||
if diff := cmp.Diff(want, got); diff != "" {
|
for i, v := range values {
|
||||||
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
tokens[i] = logit{
|
||||||
|
id: int32(i),
|
||||||
|
value: float32(v),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to compare logit slices
|
||||||
|
func compareLogits(t *testing.T, name string, want []float64, got []logit) {
|
||||||
|
t.Helper()
|
||||||
|
if len(want) != len(got) {
|
||||||
|
t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if math.Abs(float64(got[i].value)-want[i]) > 1e-6 {
|
||||||
|
t.Errorf("%s: index %d: want %f, got %f", name, i, want[i], got[i].value)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSoftmax(t *testing.T) {
|
func TestTemperature(t *testing.T) {
|
||||||
got := softmax([]float64{-3, -2, -1, 0, 1, 2, 4})
|
input := []float64{2, -1, 4, -3, 1, -2, 0}
|
||||||
|
want := []float64{-4, -10, 0, -14, -6, -12, -8} // (logit - max logit) / temp
|
||||||
|
|
||||||
want := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
|
got := temperature(toLogits(input), 0.5)
|
||||||
if diff := cmp.Diff(want, got); diff != "" {
|
compareLogits(t, "Temperature", want, got)
|
||||||
t.Errorf("probs mismatch (-want +got):\n%s", diff)
|
}
|
||||||
|
|
||||||
|
func TestSoftmax(t *testing.T) {
|
||||||
|
input := []float64{-3, -2, -1, 0, 1, 2, 4}
|
||||||
|
got := softmax(toLogits(input))
|
||||||
|
|
||||||
|
// Check probabilities sum to 1
|
||||||
|
var sum float32
|
||||||
|
for _, token := range got {
|
||||||
|
sum += token.value
|
||||||
|
}
|
||||||
|
if math.Abs(float64(sum)-1.0) > 1e-6 {
|
||||||
|
t.Errorf("probabilities don't sum to 1: got %f", sum)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check relative ordering is preserved
|
||||||
|
for i := 1; i < len(got); i++ {
|
||||||
|
if got[i].value < got[i-1].value {
|
||||||
|
t.Errorf("probability ordering not preserved at index %d", i)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTopK(t *testing.T) {
|
func TestTopK(t *testing.T) {
|
||||||
got := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
input := []float64{-3, -2, -1, 0, 1, 2, 4}
|
||||||
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4}
|
|
||||||
if diff := cmp.Diff(want, got); diff != "" {
|
|
||||||
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
|
||||||
}
|
|
||||||
|
|
||||||
got = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
// Test k=3
|
||||||
|
got := topK(toLogits(input), 3)
|
||||||
want = []float64{-3, -2, -1, 0, 1, 2, 4}
|
if len(got) != 3 {
|
||||||
if diff := cmp.Diff(want, got); diff != "" {
|
t.Errorf("topK(3): wrong length: want 3, got %d", len(got))
|
||||||
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
|
||||||
}
|
}
|
||||||
|
// Should keep highest 3 values: 4, 2, 1
|
||||||
|
want := []float64{4, 2, 1}
|
||||||
|
compareLogits(t, "topK(3)", want, got)
|
||||||
|
|
||||||
|
// Test k > len
|
||||||
|
got = topK(toLogits(input), 10)
|
||||||
|
compareLogits(t, "topK(10)", input, got)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTopP(t *testing.T) {
|
func TestTopP(t *testing.T) {
|
||||||
got := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
input := []float64{-3, -2, -1, 0, 1, 2, 4}
|
||||||
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4}
|
tokens := toLogits(input)
|
||||||
if diff := cmp.Diff(want, got); diff != "" {
|
|
||||||
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
// First apply temperature and softmax to get probabilities
|
||||||
|
tokens = temperature(tokens, 1)
|
||||||
|
tokens = softmax(tokens)
|
||||||
|
sortLogits(tokens)
|
||||||
|
|
||||||
|
// Then apply topP
|
||||||
|
got := topP(tokens, 0.95)
|
||||||
|
|
||||||
|
// Should keep tokens until cumsum > 0.95
|
||||||
|
if len(got) > 3 {
|
||||||
|
t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
|
||||||
|
t.Logf("got: %v", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMinP(t *testing.T) {
|
func TestMinP(t *testing.T) {
|
||||||
got := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3})
|
input := []float64{-3, -2, -1, 0, 1, 2, 4, 3}
|
||||||
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3}
|
tokens := toLogits(input)
|
||||||
if diff := cmp.Diff(want, got); diff != "" {
|
|
||||||
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
// First apply temperature and softmax
|
||||||
|
tokens = temperature(tokens, 1)
|
||||||
|
tokens = softmax(tokens)
|
||||||
|
|
||||||
|
// Then apply minP
|
||||||
|
got := minP(tokens, 0.2)
|
||||||
|
|
||||||
|
// Should keep tokens with prob >= 0.2 * max_prob
|
||||||
|
if len(got) > 3 {
|
||||||
|
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkTransform(b *testing.B) {
|
func TestSortLogits(t *testing.T) {
|
||||||
transforms := map[string]Transform{
|
input := []float64{3, 1, 4, 2, -1, 0, -2}
|
||||||
"Temperature": Temperature(0.5),
|
tokens := toLogits(input)
|
||||||
"TopK": TopK(10),
|
|
||||||
"TopP": TopP(0.9),
|
sortLogits(tokens)
|
||||||
"MinP": MinP(0.2),
|
|
||||||
|
for i := 1; i < len(tokens); i++ {
|
||||||
|
if tokens[i].value > tokens[i-1].value {
|
||||||
|
t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
|
||||||
|
i, tokens[i].value, tokens[i-1].value)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logits := make([]float64, 1<<16)
|
want := []float64{4, 3, 2, 1, 0, -1, -2}
|
||||||
for i := range logits {
|
compareLogits(t, "sortLogits", want, tokens)
|
||||||
logits[i] = rand.Float64()
|
}
|
||||||
}
|
|
||||||
|
func BenchmarkTransforms(b *testing.B) {
|
||||||
for name, transform := range transforms {
|
// Generate random logits
|
||||||
b.Run(name, func(b *testing.B) {
|
tokens := make([]logit, 1<<16)
|
||||||
b.ResetTimer()
|
for i := range tokens {
|
||||||
for range b.N {
|
tokens[i] = logit{
|
||||||
transform.Apply(logits)
|
id: int32(i),
|
||||||
}
|
value: rand.Float32(),
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tokensCopy := make([]logit, len(tokens))
|
||||||
|
|
||||||
|
b.Run("Temperature", func(b *testing.B) {
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
copy(tokensCopy, tokens)
|
||||||
|
temperature(tokensCopy, 0.5)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("TopK", func(b *testing.B) {
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
copy(tokensCopy, tokens)
|
||||||
|
topK(tokensCopy, 10)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("TopP", func(b *testing.B) {
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
copy(tokensCopy, tokens)
|
||||||
|
topP(tokensCopy, 0.9)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("MinP", func(b *testing.B) {
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
copy(tokensCopy, tokens)
|
||||||
|
minP(tokensCopy, 0.2)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("SortTokens", func(b *testing.B) {
|
||||||
|
b.ResetTimer()
|
||||||
|
for b.Loop() {
|
||||||
|
copy(tokensCopy, tokens)
|
||||||
|
sortLogits(tokensCopy)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user