mirror of https://github.com/ollama/ollama.git
synced 2025-03-24 00:31:55 +01:00
75 lines · 1.2 KiB · Go
|
package sample
|
||
|
|
||
|
import (
	"math"
	"slices"

	"gonum.org/v1/gonum/floats"
	"gonum.org/v1/gonum/stat/sampleuv"
)
|
||
|
|
||
|
// Sampler is one stage of a sampling pipeline run by Sample. Each stage
// receives the slice produced by the previous stage and returns either a
// transformed slice (e.g. Temperature) or, for a terminal stage such as
// Weighed, a one-element slice holding the chosen index as a float64.
// A non-nil error stops the pipeline.
//
// NOTE(review): the input presumably holds one score per token, indexed by
// token id — confirm against callers.
type Sampler interface {
	Sample(scores []float64) ([]float64, error)
}
|
||
|
|
||
|
type Temperature float64
|
||
|
|
||
|
func (s Temperature) Sample(t []float64) ([]float64, error) {
|
||
|
floats.Div(t, slices.Repeat([]float64{float64(s)}, len(t)))
|
||
|
return t, nil
|
||
|
}
|
||
|
|
||
|
type softmax struct{}
|
||
|
|
||
|
func Softmax() Sampler {
|
||
|
return softmax{}
|
||
|
}
|
||
|
|
||
|
func (softmax) Sample(t []float64) ([]float64, error) {
|
||
|
return t, nil
|
||
|
}
|
||
|
|
||
|
// TopK is a Sampler parameterized by k.
//
// NOTE(review): presumably meant to restrict sampling to the k
// highest-scoring entries, but it is not implemented yet — Sample is a
// pass-through.
type TopK int

// Sample currently returns values unchanged (the very same slice) with a nil
// error.
func (TopK) Sample(values []float64) ([]float64, error) {
	return values, nil
}
|
||
|
|
||
|
// TopP is a Sampler parameterized by a probability threshold p.
//
// NOTE(review): presumably meant for nucleus (top-p) filtering, but it is
// not implemented yet — Sample is a pass-through.
type TopP float32

// Sample currently returns values unchanged (the very same slice) with a nil
// error.
func (TopP) Sample(values []float64) ([]float64, error) {
	return values, nil
}
|
||
|
|
||
|
// MinP is a Sampler parameterized by a minimum-probability factor.
//
// NOTE(review): presumably meant for min-p filtering, but it is not
// implemented yet — Sample is a pass-through.
type MinP float32

// Sample currently returns values unchanged (the very same slice) with a nil
// error.
func (MinP) Sample(values []float64) ([]float64, error) {
	return values, nil
}
|
||
|
|
||
|
type weighed struct{}
|
||
|
|
||
|
func Weighed() Sampler {
|
||
|
return weighed{}
|
||
|
}
|
||
|
|
||
|
func (s weighed) Sample(t []float64) ([]float64, error) {
|
||
|
w := sampleuv.NewWeighted(t, nil)
|
||
|
if v, ok := w.Take(); ok {
|
||
|
return []float64{float64(v)}, nil
|
||
|
}
|
||
|
|
||
|
return t, nil
|
||
|
}
|
||
|
|
||
|
func Sample(floats []float64, samplers ...Sampler) ([]float64, error) {
|
||
|
var err error
|
||
|
for _, sampler := range samplers {
|
||
|
floats, err = sampler.Sample(floats)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return floats, nil
|
||
|
}
|