Enable array-type JSON

This commit is contained in:
ParthSareen 2025-01-29 14:54:39 -08:00
parent 77f709ebd5
commit 198fde82aa
3 changed files with 28 additions and 9 deletions

View File

@ -29,7 +29,6 @@ const (
StateInObjSpace
StateInList
StateInListComma
StateListEnd
StateInValue
StateInValueEnd
StateInListEnd
@ -63,7 +62,6 @@ var JSONStates = []JSONState{
StateInObjSpace,
StateInList,
StateInListComma,
StateListEnd,
StateInValue,
StateInValueEnd,
StateInListEnd,
@ -118,8 +116,6 @@ func (s JSONState) String() string {
return "StateInListObjectEnd"
case StateInListComma:
return "StateInListComma"
case StateListEnd:
return "StateListEnd"
case StateInListEnd:
return "StateInListEnd"
case StateInNewline:

View File

@ -47,6 +47,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
// Connect nodes
// TODO: if all are single tokens then this can just be connected instead of defining the token
stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList]
stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
@ -121,7 +122,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
// list object end
stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateListEnd]
stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
// bool node
for _, r := range validBoolRunes {
@ -129,8 +130,8 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
}
addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
stateToNodeMap[StateListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
@ -147,7 +148,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
// addEnds attaches to node the three edges that can follow a finished value:
//
//	','  -> the StateInComma node
//	'}'  -> the StateInObjectEnd node
//	']'  -> the StateInListEnd node
//
// Each target is looked up in stateToNodeMap; a state missing from the map
// yields a nil target node. node.TransitionEdges is written in place.
// NOTE(review): presumably TransitionEdges is a rune-keyed map that the caller
// has already initialised — confirm against the PDANode definition.
func addEnds(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
	node.TransitionEdges[','] = stateToNodeMap[StateInComma]
	node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
	// ']' is assigned exactly once. A preceding assignment of the same key to
	// StateListEnd was a dead store (overwritten on the very next line), so it
	// is omitted here; the resulting edge set is unchanged.
	node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
}
func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {

View File

@ -58,6 +58,27 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
case StateInString:
return s.maskLogits(logits, s.curNode)
case StateInListEnd:
fmt.Println("in list end", s.braceStack)
// force finish if no braces left
if len(s.braceStack) == 0 {
s.curNode = NewPDANode(StateTerminate)
for i := range logits {
if s.proc.Is(uint32(i), model.SpecialEOS) {
logits[i] = 1.0
} else {
logits[i] = math.NaN()
}
}
return logits, nil
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
case StateInObjectEnd:
// force finish if no braces left
if len(s.braceStack) == 0 {
@ -117,11 +138,12 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
}
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
// fmt.Println("update state", s.curNode.State)
fmt.Println("update state", s.curNode.State)
mappedString, err := s.proc.Decode(tokenSlice)
if err != nil {
return err
}
fmt.Println("mappedString", mappedString)
// TODO: should force closing for all braces - not doing square yet
for _, r := range mappedString {