mirror of
https://github.com/ollama/ollama.git
synced 2025-04-12 21:59:22 +02:00
wip
This commit is contained in:
parent
5ec6bb52a0
commit
4450f871db
@ -582,16 +582,15 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
// return
|
||||
// }
|
||||
// jsonSampler = nil
|
||||
// pythonSampler := sample.NewPythonSampler(s.model.(model.TextProcessor), nil)
|
||||
// pythonSampler := &sample.PythonSampler{}
|
||||
// functions := []sample.PythonFunction{
|
||||
// {
|
||||
// Name: "add_two_strings",
|
||||
// Args: []string{"s1", "s2"},
|
||||
// Types: []string{"string", "string"},
|
||||
// },
|
||||
// }
|
||||
// pythonSampler.Init(functions, s.model.(model.TextProcessor))
|
||||
pythonSampler := &sample.PythonSampler{}
|
||||
functions := []sample.PythonFunction{
|
||||
{
|
||||
Name: "add_two_strings",
|
||||
Args: []string{"s1", "s2"},
|
||||
Types: []string{"string", "string"},
|
||||
},
|
||||
}
|
||||
pythonSampler.Init(functions, s.model.(model.TextProcessor))
|
||||
sampler := sample.NewSampler(
|
||||
req.Options.Temperature,
|
||||
req.Options.TopK,
|
||||
@ -600,7 +599,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
req.Options.Seed,
|
||||
grammar,
|
||||
nil,
|
||||
nil,
|
||||
pythonSampler,
|
||||
// nil,
|
||||
)
|
||||
|
||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||
|
@ -108,10 +108,13 @@ type PythonSampler struct {
|
||||
proc model.TextProcessor
|
||||
decodedToks []string
|
||||
curNode *Node
|
||||
completed int
|
||||
functions []PythonFunction
|
||||
}
|
||||
|
||||
func (s *PythonSampler) Init(functions []PythonFunction, proc model.TextProcessor) error {
|
||||
s.proc = proc
|
||||
s.functions = functions
|
||||
decodedToks := make([]string, len(proc.Vocab().Values))
|
||||
for i := range proc.Vocab().Values {
|
||||
token, err := proc.Decode([]int32{int32(i)})
|
||||
@ -194,7 +197,7 @@ func (s *PythonSampler) BuildGraph() error {
|
||||
|
||||
// String end
|
||||
s.stateToNodes[PStateInStringEnd].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
|
||||
s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
|
||||
// s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
|
||||
// Number
|
||||
for _, r := range validNumberRunes {
|
||||
s.stateToNodes[PStateInNumber].TransitionEdges[r] = s.stateToNodes[PStateInNumber]
|
||||
@ -237,6 +240,16 @@ func (s *PythonSampler) UpdateState(token int32) error {
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid token: %q", mappedString)
|
||||
}
|
||||
|
||||
if mappedString == "\"" {
|
||||
if s.curNode.State == PStateInStringEnd {
|
||||
s.completed++
|
||||
}
|
||||
if s.completed == len(s.functions) {
|
||||
s.curNode.TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
|
||||
s.CreateMask(s.curNode)
|
||||
}
|
||||
}
|
||||
s.curNode = nextNode
|
||||
fmt.Println("curNode", s.curNode.State)
|
||||
for r, node := range s.curNode.TransitionEdges {
|
||||
|
Loading…
x
Reference in New Issue
Block a user