diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 3b85bc32d..cdc5100b6 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -582,16 +582,15 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { // return // } // jsonSampler = nil - // pythonSampler := sample.NewPythonSampler(s.model.(model.TextProcessor), nil) - // pythonSampler := &sample.PythonSampler{} - // functions := []sample.PythonFunction{ - // { - // Name: "add_two_strings", - // Args: []string{"s1", "s2"}, - // Types: []string{"string", "string"}, - // }, - // } - // pythonSampler.Init(functions, s.model.(model.TextProcessor)) + pythonSampler := &sample.PythonSampler{} + functions := []sample.PythonFunction{ + { + Name: "add_two_strings", + Args: []string{"s1", "s2"}, + Types: []string{"string", "string"}, + }, + } + pythonSampler.Init(functions, s.model.(model.TextProcessor)) sampler := sample.NewSampler( req.Options.Temperature, req.Options.TopK, @@ -600,7 +599,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { req.Options.Seed, grammar, nil, - nil, + pythonSampler, + // nil, ) seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ diff --git a/sample/structured_python.go b/sample/structured_python.go index 2b8de21c3..e8b4f5473 100644 --- a/sample/structured_python.go +++ b/sample/structured_python.go @@ -108,10 +108,13 @@ type PythonSampler struct { proc model.TextProcessor decodedToks []string curNode *Node + completed int + functions []PythonFunction } func (s *PythonSampler) Init(functions []PythonFunction, proc model.TextProcessor) error { s.proc = proc + s.functions = functions decodedToks := make([]string, len(proc.Vocab().Values)) for i := range proc.Vocab().Values { token, err := proc.Decode([]int32{int32(i)}) @@ -194,7 +197,7 @@ func (s *PythonSampler) BuildGraph() error { // String end s.stateToNodes[PStateInStringEnd].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs] - s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate] + // s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate] // Number for _, r := range validNumberRunes { s.stateToNodes[PStateInNumber].TransitionEdges[r] = s.stateToNodes[PStateInNumber] @@ -237,6 +240,16 @@ func (s *PythonSampler) UpdateState(token int32) error { if !ok { return fmt.Errorf("invalid token: %q", mappedString) } + + if mappedString == "\"" { + if s.curNode.State == PStateInStringEnd { + s.completed++ + } + if s.completed == len(s.functions) { + s.curNode.TransitionEdges[')'] = s.stateToNodes[PStateTerminate] + s.CreateMask(s.curNode) + } + } s.curNode = nextNode fmt.Println("curNode", s.curNode.State) for r, node := range s.curNode.TransitionEdges {