mirror of
https://github.com/ollama/ollama.git
synced 2025-04-07 11:28:17 +02:00
Enable array type json
This commit is contained in:
parent
77f709ebd5
commit
198fde82aa
@ -29,7 +29,6 @@ const (
|
||||
StateInObjSpace
|
||||
StateInList
|
||||
StateInListComma
|
||||
StateListEnd
|
||||
StateInValue
|
||||
StateInValueEnd
|
||||
StateInListEnd
|
||||
@ -63,7 +62,6 @@ var JSONStates = []JSONState{
|
||||
StateInObjSpace,
|
||||
StateInList,
|
||||
StateInListComma,
|
||||
StateListEnd,
|
||||
StateInValue,
|
||||
StateInValueEnd,
|
||||
StateInListEnd,
|
||||
@ -118,8 +116,6 @@ func (s JSONState) String() string {
|
||||
return "StateInListObjectEnd"
|
||||
case StateInListComma:
|
||||
return "StateInListComma"
|
||||
case StateListEnd:
|
||||
return "StateListEnd"
|
||||
case StateInListEnd:
|
||||
return "StateInListEnd"
|
||||
case StateInNewline:
|
||||
|
@ -47,6 +47,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
||||
// Connect nodes
|
||||
// TODO: if all are single tokens then this can just be connected instead of defining the token
|
||||
stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList]
|
||||
|
||||
stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||
stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||
@ -121,7 +122,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
||||
|
||||
// list object end
|
||||
stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
|
||||
stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateListEnd]
|
||||
stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||
|
||||
// bool node
|
||||
for _, r := range validBoolRunes {
|
||||
@ -129,8 +130,8 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
||||
}
|
||||
addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
|
||||
|
||||
stateToNodeMap[StateListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
stateToNodeMap[StateListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||
stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||
|
||||
stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
||||
@ -147,7 +148,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
||||
func addEnds(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
|
||||
node.TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||
node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
node.TransitionEdges[']'] = stateToNodeMap[StateListEnd]
|
||||
node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||
}
|
||||
|
||||
func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
|
||||
|
@ -58,6 +58,27 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
case StateInString:
|
||||
return s.maskLogits(logits, s.curNode)
|
||||
|
||||
case StateInListEnd:
|
||||
fmt.Println("in list end", s.braceStack)
|
||||
// force finish if no braces left
|
||||
if len(s.braceStack) == 0 {
|
||||
s.curNode = NewPDANode(StateTerminate)
|
||||
for i := range logits {
|
||||
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
||||
logits[i] = 1.0
|
||||
} else {
|
||||
logits[i] = math.NaN()
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
|
||||
logits, err := s.maskLogits(logits, s.curNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
case StateInObjectEnd:
|
||||
// force finish if no braces left
|
||||
if len(s.braceStack) == 0 {
|
||||
@ -117,11 +138,12 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
}
|
||||
|
||||
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||
// fmt.Println("update state", s.curNode.State)
|
||||
fmt.Println("update state", s.curNode.State)
|
||||
mappedString, err := s.proc.Decode(tokenSlice)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("mappedString", mappedString)
|
||||
|
||||
// TODO: should force closing for all braces - not doing square yet
|
||||
for _, r := range mappedString {
|
||||
|
Loading…
x
Reference in New Issue
Block a user