From 198fde82aa581a011d1f824b1df931ab5cc40878 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Wed, 29 Jan 2025 14:54:39 -0800 Subject: [PATCH] Enable array type json --- sample/fast_json.go | 4 ---- sample/pushdown_automata.go | 9 +++++---- sample/pushdown_runner.go | 24 +++++++++++++++++++++++- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/sample/fast_json.go b/sample/fast_json.go index adcc26bb7..486490f21 100644 --- a/sample/fast_json.go +++ b/sample/fast_json.go @@ -29,7 +29,6 @@ const ( StateInObjSpace StateInList StateInListComma - StateListEnd StateInValue StateInValueEnd StateInListEnd @@ -63,7 +62,6 @@ var JSONStates = []JSONState{ StateInObjSpace, StateInList, StateInListComma, - StateListEnd, StateInValue, StateInValueEnd, StateInListEnd, @@ -118,8 +116,6 @@ func (s JSONState) String() string { return "StateInListObjectEnd" case StateInListComma: return "StateInListComma" - case StateListEnd: - return "StateListEnd" case StateInListEnd: return "StateInListEnd" case StateInNewline: diff --git a/sample/pushdown_automata.go b/sample/pushdown_automata.go index b8f83a8bd..7bdecf38d 100644 --- a/sample/pushdown_automata.go +++ b/sample/pushdown_automata.go @@ -47,6 +47,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err // Connect nodes // TODO: if all are single tokens then this can just be connected instead of defining the token stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject] + stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList] stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] @@ -121,7 +122,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err // list object end stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma] - stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateListEnd] + stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd] // bool node for _, r := range validBoolRunes { @@ -129,8 +130,8 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err } addEnds(stateToNodeMap[StateInBool], stateToNodeMap) - stateToNodeMap[StateListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] - stateToNodeMap[StateListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma] + stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] + stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma] stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInList] @@ -147,7 +148,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err func addEnds(node *PDANode, stateToNodeMap map[JSONState]*PDANode) { node.TransitionEdges[','] = stateToNodeMap[StateInComma] node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] - node.TransitionEdges[']'] = stateToNodeMap[StateListEnd] + node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd] } func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) { diff --git a/sample/pushdown_runner.go b/sample/pushdown_runner.go index cb467e81c..e1d2cce65 100644 --- a/sample/pushdown_runner.go +++ b/sample/pushdown_runner.go @@ -58,6 +58,27 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { case StateInString: return s.maskLogits(logits, s.curNode) + case StateInListEnd: + fmt.Println("in list end", s.braceStack) + // force finish if no braces left + if len(s.braceStack) == 0 { + s.curNode = NewPDANode(StateTerminate) + for i := range logits { + if s.proc.Is(uint32(i), model.SpecialEOS) { + logits[i] = 1.0 + } else { + logits[i] = math.NaN() + } + } + return logits, nil + } + + logits, err := s.maskLogits(logits, s.curNode) + if err != nil { + return nil, err + } + return logits, nil + case StateInObjectEnd: // force finish if no braces left if len(s.braceStack) == 0 { @@ -117,11 +138,12 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { } func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { - // fmt.Println("update state", s.curNode.State) + fmt.Println("update state", s.curNode.State) mappedString, err := s.proc.Decode(tokenSlice) if err != nil { return err } + fmt.Println("mappedString", mappedString) // TODO: should force closing for all braces - not doing square yet for _, r := range mappedString {