This commit is contained in:
ParthSareen 2025-03-25 16:45:27 -07:00
parent 5ec6bb52a0
commit 4450f871db
2 changed files with 25 additions and 12 deletions

View File

@ -582,16 +582,15 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
// return
// }
// jsonSampler = nil
// pythonSampler := sample.NewPythonSampler(s.model.(model.TextProcessor), nil)
// pythonSampler := &sample.PythonSampler{}
// functions := []sample.PythonFunction{
// {
// Name: "add_two_strings",
// Args: []string{"s1", "s2"},
// Types: []string{"string", "string"},
// },
// }
// pythonSampler.Init(functions, s.model.(model.TextProcessor))
pythonSampler := &sample.PythonSampler{}
functions := []sample.PythonFunction{
{
Name: "add_two_strings",
Args: []string{"s1", "s2"},
Types: []string{"string", "string"},
},
}
pythonSampler.Init(functions, s.model.(model.TextProcessor))
sampler := sample.NewSampler(
req.Options.Temperature,
req.Options.TopK,
@ -600,7 +599,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
req.Options.Seed,
grammar,
nil,
nil,
pythonSampler,
// nil,
)
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{

View File

@ -108,10 +108,13 @@ type PythonSampler struct {
proc model.TextProcessor
decodedToks []string
curNode *Node
completed int
functions []PythonFunction
}
func (s *PythonSampler) Init(functions []PythonFunction, proc model.TextProcessor) error {
s.proc = proc
s.functions = functions
decodedToks := make([]string, len(proc.Vocab().Values))
for i := range proc.Vocab().Values {
token, err := proc.Decode([]int32{int32(i)})
@ -194,7 +197,7 @@ func (s *PythonSampler) BuildGraph() error {
// String end
s.stateToNodes[PStateInStringEnd].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
// s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
// Number
for _, r := range validNumberRunes {
s.stateToNodes[PStateInNumber].TransitionEdges[r] = s.stateToNodes[PStateInNumber]
@ -237,6 +240,16 @@ func (s *PythonSampler) UpdateState(token int32) error {
if !ok {
return fmt.Errorf("invalid token: %q", mappedString)
}
if mappedString == "\"" {
if s.curNode.State == PStateInStringEnd {
s.completed++
}
if s.completed == len(s.functions) {
s.curNode.TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
s.CreateMask(s.curNode)
}
}
s.curNode = nextNode
fmt.Println("curNode", s.curNode.State)
for r, node := range s.curNode.TransitionEdges {