package sample

import (
	"errors"
	"math"
	"slices"

	"gonum.org/v1/gonum/floats"
	"gonum.org/v1/gonum/stat/sampleuv"
)

// Sampler is one stage of a sampling pipeline. Sample receives a slice of
// values and returns the (possibly transformed) slice to hand to the next
// stage, or a non-nil error to stop the pipeline.
type Sampler interface {
	Sample(values []float64) ([]float64, error)
}

// Temperature is a Sampler that divides every input value by the temperature.
// Dividing by a temperature below 1 spreads the values further apart; a
// temperature above 1 pulls them toward zero.
type Temperature float64

// Sample divides each element of t by s in place and returns t.
//
// It returns an error when s is not strictly positive: dividing by zero would
// produce ±Inf (or NaN for zero-valued inputs), and a negative temperature
// would invert the ordering of the values.
func (s Temperature) Sample(t []float64) ([]float64, error) {
	// Written as !(s > 0) rather than s <= 0 so that a NaN temperature is
	// rejected as well.
	if !(s > 0) {
		return nil, errors.New("sample: temperature must be positive")
	}

	// NOTE(review): slices.Repeat allocates a len(t) divisor slice on every
	// call; floats.Scale(1/float64(s), t) would avoid that, at the cost of
	// rounding slightly differently from true division.
	floats.Div(t, slices.Repeat([]float64{float64(s)}, len(t)))
	return t, nil
}

// softmax is the Sampler returned by Softmax.
type softmax struct{}

// Softmax returns a Sampler that turns its input into a probability
// distribution: each element x_i is replaced by exp(x_i) / sum_j exp(x_j).
func Softmax() Sampler {
	return softmax{}
}

// Sample applies the softmax function to t in place and returns t.
//
// The maximum element is subtracted before exponentiating so that large
// inputs do not overflow to +Inf. This does not change the result, because
// softmax is invariant to adding a constant to every element. Entries equal
// to -Inf become 0.
//
// An empty input is returned unchanged. An error is returned when the maximum
// is not finite (every entry is -Inf, or some entry is +Inf or NaN), because
// the distribution is then undefined.
func (softmax) Sample(t []float64) ([]float64, error) {
	if len(t) == 0 {
		return t, nil
	}

	// slices.Max propagates NaN, so the check below also catches NaN inputs.
	maxVal := slices.Max(t)
	if math.IsInf(maxVal, 0) || math.IsNaN(maxVal) {
		return nil, errors.New("sample: softmax requires a finite maximum value")
	}

	// The max element contributes exp(0) = 1, so sum >= 1 and the division
	// below can never divide by zero.
	var sum float64
	for i, v := range t {
		e := math.Exp(v - maxVal)
		t[i] = e
		sum += e
	}
	for i := range t {
		t[i] /= sum
	}
	return t, nil
}

// TopK is, by its name, presumably meant to restrict sampling to the K
// highest-valued entries.
//
// NOTE(review): Sample is currently a pass-through placeholder. It filters
// nothing and returns its input slice unchanged.
type TopK int

// Sample returns values unchanged; no top-k filtering is applied yet.
func (k TopK) Sample(values []float64) ([]float64, error) {
	// Placeholder: identity transform.
	return values, nil
}

// TopP is, by its name, presumably meant to implement nucleus (top-p)
// filtering.
//
// NOTE(review): Sample is currently a pass-through placeholder. It filters
// nothing and returns its input slice unchanged.
type TopP float32

// Sample returns values unchanged; no top-p filtering is applied yet.
func (p TopP) Sample(values []float64) ([]float64, error) {
	// Placeholder: identity transform.
	return values, nil
}

// MinP is, by its name, presumably meant to implement min-p filtering.
//
// NOTE(review): Sample is currently a pass-through placeholder. It filters
// nothing and returns its input slice unchanged.
type MinP float32

// Sample returns values unchanged; no min-p filtering is applied yet.
func (m MinP) Sample(values []float64) ([]float64, error) {
	// Placeholder: identity transform.
	return values, nil
}

// weighed is the Sampler returned by Weighed.
type weighed struct{}

// Weighed returns a Sampler that draws a single index from its input, using
// each element as that index's sampling weight.
func Weighed() Sampler {
	return weighed{}
}

// Sample draws one index i with probability proportional to t[i]. It returns
// the index as the one-element slice []float64{float64(i)}.
//
// t should hold weights suitable for gonum's sampleuv.Weighted. These are
// presumably non-negative, such as Softmax output (TODO confirm how negative
// weights behave).
//
// It returns an error when t is empty or when no index can be drawn. This
// replaces silently handing back the unchanged input, which a caller could
// mistake for a sampled result.
func (weighed) Sample(t []float64) ([]float64, error) {
	// Guard explicitly: there is nothing to draw from an empty weight set.
	if len(t) == 0 {
		return nil, errors.New("sample: no weights to sample from")
	}

	w := sampleuv.NewWeighted(t, nil)
	idx, ok := w.Take()
	if !ok {
		// Take reports !ok when no weight remains to draw from.
		return nil, errors.New("sample: weights sum to zero")
	}

	return []float64{float64(idx)}, nil
}

// Sample passes values through each sampler in order. It feeds the output of
// one stage into the next and returns the final result.
//
// If any sampler returns an error, Sample stops immediately and returns a nil
// slice with that error. With no samplers, the input is returned unchanged.
func Sample(values []float64, samplers ...Sampler) ([]float64, error) {
	// The parameter is deliberately not named "floats", which would shadow
	// the imported gonum floats package inside this function.
	for _, sampler := range samplers {
		var err error
		values, err = sampler.Sample(values)
		if err != nil {
			return nil, err
		}
	}

	return values, nil
}