From 9039c821a2c572e8bd0ee5cde13e4cb55c332e35 Mon Sep 17 00:00:00 2001
From: Blake Mizerany
Date: Wed, 11 Dec 2024 14:07:30 -0800
Subject: [PATCH] llama: preserve field order in user-defined JSON schemas (#8002)

Previously we decoded and re-encoded JSON schemas during validation,
which served no purpose since json.RawMessage already validates JSON
syntax. Worse, the re-encoding lost field ordering from the original
schema, which affects inference quality during step-by-step reasoning.

While fixing this ordering issue by using json.RawMessage directly, we
found through testing that schema_to_grammar (from llama.cpp) also
fails to preserve field order during grammar generation. This appears
to be the root cause of the inference degradation.

This change prevents us from mangling the user's original schema order,
but we still need to address the ordering issue in schema_to_grammar.
That will be a separate change.

Updates #7978
---
 llama/grammar_test.go | 80 +++++++++++++++++++++++++++++++++++++++++++
 llama/llama.go        | 32 +++++------------
 llama/llama_test.go   | 69 -------------------------------------
 llm/server.go         | 29 +++++++---------
 openai/openai.go      |  8 ++---
 5 files changed, 104 insertions(+), 114 deletions(-)
 create mode 100644 llama/grammar_test.go

diff --git a/llama/grammar_test.go b/llama/grammar_test.go
new file mode 100644
index 000000000..b64976ae8
--- /dev/null
+++ b/llama/grammar_test.go
@@ -0,0 +1,80 @@
+package llama
+
+import (
+	"bufio"
+	"bytes"
+	"strings"
+	"testing"
+)
+
+// https://github.com/ollama/ollama/issues/7978
+const issue7978JSONSchema = `{
+  "type": "object",
+  "properties": {
+    "steps": {
+      "type": "array",
+      "items": {
+        "type": "object",
+        "properties": {
+          "explanation": { "type": "string" },
+          "output": { "type": "string" }
+        },
+        "required": ["explanation", "output"],
+        "additionalProperties": false
+      }
+    },
+    "final_answer": { "type": "string" }
+  },
+  "required": ["steps", "final_answer"],
+  "additionalProperties": false
+}`
+
+func TestIssue7978(t *testing.T) {
+	t.Skip("schema_to_grammar is broken; skipping until fixed")
+
+	g := SchemaToGrammar([]byte(issue7978JSONSchema))
+	if g == nil {
+		t.Fatal("failed to convert JSON schema to grammar")
+	}
+
+	t.Logf("grammar:\n%s", g)
+	t.Log()
+
+	var sawSteps bool
+	s := bufio.NewScanner(bytes.NewReader(g))
+	for s.Scan() {
+		line := s.Text()
+		if strings.Contains(line, "steps") {
+			sawSteps = true
+		}
+		if strings.Contains(line, "final-answer") && !sawSteps {
+			t.Error("expected 'steps' before 'final-answer'")
+		}
+	}
+}
+
+func TestSchemaToGrammar(t *testing.T) {
+	t.Skip("schema_to_grammar is broken; skipping until fixed")
+
+	cases := []struct {
+		schema string
+		prefix []byte // nil means the returned grammar should be nil
+	}{
+		{`invalid`, nil},
+
+		// Simple heuristic/smoke test
+		{`{"type":"object"}`, []byte("object ::=")},
+	}
+
+	for _, c := range cases {
+		t.Run("x", func(t *testing.T) {
+			g := SchemaToGrammar([]byte(c.schema))
+			if c.prefix == nil && g != nil {
+				t.Fatalf("grammar = %v, want nil", g)
+			}
+			if !bytes.HasPrefix(g, c.prefix) {
+				t.Errorf("grammar = %q, want %q", g, c.prefix)
+			}
+		})
+	}
+}
diff --git a/llama/llama.go b/llama/llama.go
index b770ca564..15e719798 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -86,12 +86,9 @@ COMPILER inline get_compiler() {
 import "C"
 
 import (
-	"bytes"
 	_ "embed"
-	"encoding/json"
 	"errors"
 	"fmt"
-	"log/slog"
 	"runtime"
 	"runtime/cgo"
 	"slices"
@@ -721,21 +718,10 @@ func (s *SamplingContext) Accept(id int, applyGrammar bool) {
 	C.common_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
 }
 
-type JsonSchema struct {
-	Defs       map[string]any `json:"$defs,omitempty"`
-	Properties map[string]any `json:"properties,omitempty"`
-	Required   []string       `json:"required,omitempty"`
-	Title      string         `json:"title,omitempty"`
-	Type       string         `json:"type,omitempty"`
-}
-
-func (js JsonSchema) AsGrammar() string {
-	var b bytes.Buffer
-	if err := json.NewEncoder(&b).Encode(js); err != nil {
-		return ""
-	}
-
-	cStr := C.CString(b.String())
+// SchemaToGrammar converts the provided JSON schema to a grammar. It returns
+// nil if the provided schema is invalid JSON or an invalid JSON schema.
+func SchemaToGrammar(schema []byte) []byte {
+	cStr := C.CString(string(schema))
 	defer C.free(unsafe.Pointer(cStr))
 
 	// Allocate buffer for grammar output with reasonable size
@@ -743,10 +729,10 @@ func (js JsonSchema) AsGrammar() string {
 	buf := make([]byte, maxLen)
 
 	// Call C function to convert schema to grammar
-	length := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
-	if length == 0 {
-		slog.Warn("unable to convert schema to grammar")
+	n := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
+	if n == 0 {
+		// preserve nil
+		return nil
 	}
-
-	return string(buf[:length])
+	return buf[:n]
 }
diff --git a/llama/llama_test.go b/llama/llama_test.go
index 4fab133d2..5f835d683 100644
--- a/llama/llama_test.go
+++ b/llama/llama_test.go
@@ -1,70 +1 @@
 package llama
-
-import (
-	"strings"
-	"testing"
-
-	"github.com/google/go-cmp/cmp"
-)
-
-func TestJsonSchema(t *testing.T) {
-	testCases := []struct {
-		name     string
-		schema   JsonSchema
-		expected string
-	}{
-		{
-			name: "empty schema",
-			schema: JsonSchema{
-				Type: "object",
-			},
-			expected: `array ::= "[" space ( value ("," space value)* )? "]" space
-boolean ::= ("true" | "false") space
-char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
-decimal-part ::= [0-9]{1,16}
-integral-part ::= [0] | [1-9] [0-9]{0,15}
-null ::= "null" space
-number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
-object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
-root ::= object
-space ::= | " " | "\n" [ \t]{0,20}
-string ::= "\"" char* "\"" space
-value ::= object | array | string | number | boolean | null`,
-		},
-		{
-			name: "invalid schema with circular reference",
-			schema: JsonSchema{
-				Type: "object",
-				Properties: map[string]any{
-					"self": map[string]any{
-						"$ref": "#", // Self reference
-					},
-				},
-			},
-			expected: "", // Should return empty string for invalid schema
-		},
-		{
-			name: "schema with invalid type",
-			schema: JsonSchema{
-				Type: "invalid_type", // Invalid type
-				Properties: map[string]any{
-					"foo": map[string]any{
-						"type": "string",
-					},
-				},
-			},
-			expected: "", // Should return empty string for invalid schema
-		},
-	}
-
-	for _, tc := range testCases {
-		t.Run(tc.name, func(t *testing.T) {
-			result := tc.schema.AsGrammar()
-			if !strings.EqualFold(strings.TrimSpace(result), strings.TrimSpace(tc.expected)) {
-				if diff := cmp.Diff(tc.expected, result); diff != "" {
-					t.Fatalf("grammar mismatch (-want +got):\n%s", diff)
-				}
-			}
-		})
-	}
-}
diff --git a/llm/server.go b/llm/server.go
index 37c204678..dc016ccb8 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -610,7 +610,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 	}
 }
 
-const jsonGrammar = `
+var grammarJSON = `
 root ::= object
 value ::= object | array | string | number | ("true" | "false" | "null") ws
 object ::=
@@ -722,22 +722,19 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 		return fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
-	// TODO (parthsareen): Move conversion to grammar with sampling logic
-	// API should do error handling for invalid formats
-	if req.Format != nil && strings.TrimSpace(string(req.Format)) != "null" {
-		if strings.ToLower(strings.TrimSpace(string(req.Format))) == `"json"` {
-			request["grammar"] = jsonGrammar
-			if !strings.Contains(strings.ToLower(req.Prompt), "json") {
-				slog.Warn("prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
+	if len(req.Format) > 0 {
+		switch {
+		case bytes.Equal(req.Format, []byte(`"json"`)):
+			request["grammar"] = grammarJSON
+		case bytes.HasPrefix(req.Format, []byte("{")):
+			// User provided a JSON schema
+			g := llama.SchemaToGrammar(req.Format)
+			if g == nil {
+				return fmt.Errorf("invalid JSON schema in format")
 			}
-		} else if schema, err := func() (llama.JsonSchema, error) {
-			var schema llama.JsonSchema
-			err := json.Unmarshal(req.Format, &schema)
-			return schema, err
-		}(); err == nil {
-			request["grammar"] = schema.AsGrammar()
-		} else {
-			slog.Warn(`format is neither a schema or "json"`, "format", req.Format)
+			request["grammar"] = string(g)
+		default:
+			return errors.New(`invalid format: expected "json" or a JSON schema`)
 		}
 	}
 
diff --git a/openai/openai.go b/openai/openai.go
index 3a35d9dda..6b28eee42 100644
--- a/openai/openai.go
+++ b/openai/openai.go
@@ -67,7 +67,7 @@ type ResponseFormat struct {
 }
 
 type JsonSchema struct {
-	Schema map[string]any `json:"schema"`
+	Schema json.RawMessage `json:"schema"`
 }
 
 type EmbedRequest struct {
@@ -495,11 +495,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 			format = json.RawMessage(`"json"`)
 		case "json_schema":
 			if r.ResponseFormat.JsonSchema != nil {
-				schema, err := json.Marshal(r.ResponseFormat.JsonSchema.Schema)
-				if err != nil {
-					return nil, fmt.Errorf("failed to marshal json schema: %w", err)
-				}
-				format = schema
+				format = r.ResponseFormat.JsonSchema.Schema
 			}
 		}
 	}
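
Note (not part of the patch): a minimal sketch of the ordering problem described in the commit message, loosely mirroring the old and new paths. encoding/json writes map keys in sorted order when re-encoding, while json.RawMessage keeps the caller's bytes, and therefore the declared property order, intact. The trimmed-down schema below is illustrative only.

	package main

	import (
		"encoding/json"
		"fmt"
	)

	func main() {
		// "steps" is declared before "final_answer", as in the issue #7978 schema.
		raw := []byte(`{"properties":{"steps":{"type":"array"},"final_answer":{"type":"string"}}}`)

		// Old path (roughly): decode into Go maps, then re-encode.
		// encoding/json marshals map keys in sorted order, so
		// "final_answer" now precedes "steps".
		var m map[string]any
		if err := json.Unmarshal(raw, &m); err != nil {
			panic(err)
		}
		reencoded, _ := json.Marshal(m)
		fmt.Println(string(reencoded))

		// New path (roughly): keep the schema as json.RawMessage.
		// Unmarshal still rejects malformed JSON, but the original
		// bytes and property order are preserved.
		var keep json.RawMessage
		if err := json.Unmarshal(raw, &keep); err != nil {
			panic(err)
		}
		fmt.Println(string(keep))
	}

Running it prints the re-encoded schema with its keys reordered, followed by the untouched original.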