From 630e7dc6ff461cc957a1314d8f27986f0d7b92ca Mon Sep 17 00:00:00 2001
From: Parth Sareen
Date: Wed, 4 Dec 2024 16:31:19 -0800
Subject: [PATCH] api: structured outputs - chat endpoint (#7900)

Adds structured outputs to the chat endpoint.

---------

Co-authored-by: Michael Yang
Co-authored-by: Hieu Nguyen
---
 api/types.go           |  2 +-
 cmd/cmd.go             |  3 +-
 llama/llama.go         | 33 ++++++++++++++++++++
 llama/llama_test.go    | 69 ++++++++++++++++++++++++++++++++++++++++++
 llama/sampling_ext.cpp | 29 ++++++++++++++++--
 llama/sampling_ext.h   |  2 ++
 llm/server.go          | 27 +++++++++++------
 openai/openai.go       | 25 ++++++++++++---
 openai/openai_test.go  | 13 ++++----
 server/routes.go       |  2 +-
 10 files changed, 180 insertions(+), 25 deletions(-)

diff --git a/api/types.go b/api/types.go
index d2108f884..a04fd9943 100644
--- a/api/types.go
+++ b/api/types.go
@@ -94,7 +94,7 @@ type ChatRequest struct {
 	Stream *bool `json:"stream,omitempty"`
 
 	// Format is the format to return the response in (e.g. "json").
-	Format string `json:"format"`
+	Format json.RawMessage `json:"format,omitempty"`
 
 	// KeepAlive controls how long the model will stay loaded into memory
 	// following the request.
diff --git a/cmd/cmd.go b/cmd/cmd.go
index b863264f5..8b2031316 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -8,6 +8,7 @@ import (
 	"crypto/ed25519"
 	"crypto/rand"
 	"crypto/sha256"
+	"encoding/json"
 	"encoding/pem"
 	"errors"
 	"fmt"
@@ -1038,7 +1039,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 	req := &api.ChatRequest{
 		Model:    opts.Model,
 		Messages: opts.Messages,
-		Format:   opts.Format,
+		Format:   json.RawMessage(opts.Format),
 		Options:  opts.Options,
 	}
diff --git a/llama/llama.go b/llama/llama.go
index 24fa75274..97b58663c 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -85,9 +85,12 @@ COMPILER inline get_compiler() {
 import "C"
 
 import (
+	"bytes"
 	_ "embed"
+	"encoding/json"
 	"errors"
 	"fmt"
+	"log/slog"
 	"runtime"
 	"runtime/cgo"
 	"slices"
@@ -699,3 +702,33 @@ func (s *SamplingContext) Sample(llamaContext *Context, idx int) int {
 func (s *SamplingContext) Accept(id int, applyGrammar bool) {
 	C.gpt_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
 }
+
+type JsonSchema struct {
+	Defs       map[string]any `json:"$defs,omitempty"`
+	Properties map[string]any `json:"properties,omitempty"`
+	Required   []string       `json:"required,omitempty"`
+	Title      string         `json:"title,omitempty"`
+	Type       string         `json:"type,omitempty"`
+}
+
+func (js JsonSchema) AsGrammar() string {
+	var b bytes.Buffer
+	if err := json.NewEncoder(&b).Encode(js); err != nil {
+		return ""
+	}
+
+	cStr := C.CString(b.String())
+	defer C.free(unsafe.Pointer(cStr))
+
+	// Allocate buffer for grammar output with reasonable size
+	const maxLen = 32768 // 32KB
+	buf := make([]byte, maxLen)
+
+	// Call C function to convert schema to grammar
+	length := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
+	if length == 0 {
+		slog.Warn("unable to convert schema to grammar")
+	}
+
+	return string(buf[:length])
+}
diff --git a/llama/llama_test.go b/llama/llama_test.go
index 5f835d683..4fab133d2 100644
--- a/llama/llama_test.go
+++ b/llama/llama_test.go
@@ -1 +1,70 @@
 package llama
+
+import (
+	"strings"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+)
+
+func TestJsonSchema(t *testing.T) {
+	testCases := []struct {
+		name     string
+		schema   JsonSchema
+		expected string
+	}{
+		{
+			name: "empty schema",
+			schema: JsonSchema{
+				Type: "object",
+			},
+			expected: `array ::= "[" space ( value ("," space value)* )? "]" space
+boolean ::= ("true" | "false") space
+char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+decimal-part ::= [0-9]{1,16}
+integral-part ::= [0] | [1-9] [0-9]{0,15}
+null ::= "null" space
+number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
+object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
+root ::= object
+space ::= | " " | "\n" [ \t]{0,20}
+string ::= "\"" char* "\"" space
+value ::= object | array | string | number | boolean | null`,
+		},
+		{
+			name: "invalid schema with circular reference",
+			schema: JsonSchema{
+				Type: "object",
+				Properties: map[string]any{
+					"self": map[string]any{
+						"$ref": "#", // Self reference
+					},
+				},
+			},
+			expected: "", // Should return empty string for invalid schema
+		},
+		{
+			name: "schema with invalid type",
+			schema: JsonSchema{
+				Type: "invalid_type", // Invalid type
+				Properties: map[string]any{
+					"foo": map[string]any{
+						"type": "string",
+					},
+				},
+			},
+			expected: "", // Should return empty string for invalid schema
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			result := tc.schema.AsGrammar()
+			if !strings.EqualFold(strings.TrimSpace(result), strings.TrimSpace(tc.expected)) {
+				if diff := cmp.Diff(tc.expected, result); diff != "" {
+					t.Fatalf("grammar mismatch (-want +got):\n%s", diff)
+				}
+			}
+		})
+	}
+}
diff --git a/llama/sampling_ext.cpp b/llama/sampling_ext.cpp
index 3dd7edf49..469c7b6b6 100644
--- a/llama/sampling_ext.cpp
+++ b/llama/sampling_ext.cpp
@@ -1,11 +1,13 @@
 // TODO: this is a temporary wrapper to allow calling C++ code from CGo
 #include "sampling.h"
 #include "sampling_ext.h"
+#include "json-schema-to-grammar.h"
 
 struct gpt_sampler *gpt_sampler_cinit(
     const struct llama_model *model, struct gpt_sampler_cparams *params)
 {
-    try {
+    try
+    {
         gpt_sampler_params sparams;
         sparams.top_k = params->top_k;
         sparams.top_p = params->top_p;
@@ -24,7 +26,9 @@ struct gpt_sampler *gpt_sampler_cinit(
         sparams.seed = params->seed;
         sparams.grammar = params->grammar;
         return gpt_sampler_init(model, sparams);
-    } catch (const std::exception & err) {
+    }
+    catch (const std::exception &err)
+    {
         return nullptr;
     }
 }
@@ -54,3 +58,24 @@ void gpt_sampler_caccept(
 {
     gpt_sampler_accept(sampler, id, apply_grammar);
 }
+
+int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len)
+{
+    try
+    {
+        nlohmann::json schema = nlohmann::json::parse(json_schema);
+        std::string grammar_str = json_schema_to_grammar(schema);
+        size_t len = grammar_str.length();
+        if (len >= max_len)
+        {
+            len = max_len - 1;
+        }
+        strncpy(grammar, grammar_str.c_str(), len);
+        return len;
+    }
+    catch (const std::exception &e)
+    {
+        strncpy(grammar, "", max_len - 1);
+        return 0;
+    }
+}
diff --git a/llama/sampling_ext.h b/llama/sampling_ext.h
index ec919a488..db868c2c2 100644
--- a/llama/sampling_ext.h
+++ b/llama/sampling_ext.h
@@ -47,6 +47,8 @@ extern "C"
         llama_token id,
         bool apply_grammar);
 
+    int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/llm/server.go b/llm/server.go
index debdd35e8..ad5306d7d 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -634,27 +634,22 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 
 const jsonGrammar = `
 root   ::= object
 value  ::= object | array | string | number | ("true" | "false" | "null") ws
-
 object ::=
   "{" ws (
             string ":" ws value
     ("," ws string ":" ws value)*
   )? "}" ws
-
 array  ::=
   "[" ws (
             value
     ("," ws value)*
   )? "]" ws
-
 string ::=
   "\"" (
     [^"\\\x7F\x00-\x1F] |
     "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
   )* "\"" ws
-
 number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
-
 # Optional space: by convention, applied in this grammar after literal chars when allowed
 ws ::= ([ \t\n] ws)?
 `
@@ -684,7 +679,7 @@ type completion struct {
 
 type CompletionRequest struct {
 	Prompt  string
-	Format  string
+	Format  json.RawMessage
 	Images  []ImageData
 	Options *api.Options
 }
@@ -749,10 +744,22 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 		return fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
-	if req.Format == "json" {
-		request["grammar"] = jsonGrammar
-		if !strings.Contains(strings.ToLower(req.Prompt), "json") {
-			slog.Warn("Prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
+	// TODO (parthsareen): Move conversion to grammar with sampling logic
+	// API should do error handling for invalid formats
+	if req.Format != nil {
+		if strings.ToLower(strings.TrimSpace(string(req.Format))) == `"json"` {
+			request["grammar"] = jsonGrammar
+			if !strings.Contains(strings.ToLower(req.Prompt), "json") {
+				slog.Warn("prompt does not specify that the LLM should respond in JSON, but JSON format is expected. For best results, specify that JSON is expected in the system prompt.")
+			}
+		} else if schema, err := func() (llama.JsonSchema, error) {
+			var schema llama.JsonSchema
+			err := json.Unmarshal(req.Format, &schema)
+			return schema, err
+		}(); err == nil {
+			request["grammar"] = schema.AsGrammar()
+		} else {
+			slog.Warn(`format is neither a schema nor "json"`, "format", req.Format)
 		}
 	}
diff --git a/openai/openai.go b/openai/openai.go
index bf1879f97..3a35d9dda 100644
--- a/openai/openai.go
+++ b/openai/openai.go
@@ -62,7 +62,12 @@ type Usage struct {
 }
 
 type ResponseFormat struct {
-	Type string `json:"type"`
+	Type       string      `json:"type"`
+	JsonSchema *JsonSchema `json:"json_schema,omitempty"`
+}
+
+type JsonSchema struct {
+	Schema map[string]any `json:"schema"`
 }
 
 type EmbedRequest struct {
@@ -482,9 +487,21 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 		options["top_p"] = 1.0
 	}
 
-	var format string
-	if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
-		format = "json"
+	var format json.RawMessage
+	if r.ResponseFormat != nil {
+		switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
+		// Support the old "json_object" type for OpenAI compatibility
+		case "json_object":
+			format = json.RawMessage(`"json"`)
+		case "json_schema":
+			if r.ResponseFormat.JsonSchema != nil {
+				schema, err := json.Marshal(r.ResponseFormat.JsonSchema.Schema)
+				if err != nil {
+					return nil, fmt.Errorf("failed to marshal json schema: %w", err)
+				}
+				format = schema
+			}
+		}
 	}
 
 	return &api.ChatRequest{
diff --git a/openai/openai_test.go b/openai/openai_test.go
index e17037dea..0c2a7d806 100644
--- a/openai/openai_test.go
+++ b/openai/openai_test.go
@@ -13,6 +13,7 @@ import (
 	"time"
 
 	"github.com/gin-gonic/gin"
+	"github.com/google/go-cmp/cmp"
 
 	"github.com/ollama/ollama/api"
 )
@@ -107,7 +108,7 @@ func TestChatMiddleware(t *testing.T) {
 				"presence_penalty": 5.0,
 				"top_p":            6.0,
 			},
-			Format: "json",
+			Format: json.RawMessage(`"json"`),
 			Stream: &True,
 		},
 	},
@@ -316,13 +317,13 @@ func TestChatMiddleware(t *testing.T) {
 			if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
 				t.Fatal(err)
 			}
+			return
 		}
-		if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
-			t.Fatal("requests did not match")
+		if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
+			t.Fatalf("requests did not match: %s", diff)
 		}
-
-		if !reflect.DeepEqual(tc.err, errResp) {
-			t.Fatal("errors did not match")
+		if diff := cmp.Diff(tc.err, errResp); diff != "" {
+			t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
 		}
 	})
 }
diff --git a/server/routes.go b/server/routes.go
index b8980a65e..bab29757e 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -278,7 +278,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 		Prompt:  prompt,
 		Images:  images,
-		Format:  req.Format,
+		Format:  json.RawMessage(req.Format),
 		Options: opts,
 	}, func(cr llm.CompletionResponse) {
 		res := api.GenerateResponse{
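
Usage sketch (illustrative, not part of the patch): with this change, api.ChatRequest.Format
accepts either the JSON string "json" (the legacy JSON mode) or a JSON schema object, which
llm/server.go converts to a GBNF grammar for constrained sampling. A minimal client-side
example using the ollama Go API package follows; the model name, prompt, and schema are
assumptions made for illustration, not values taken from the patch:

    package main

    import (
        "context"
        "encoding/json"
        "fmt"
        "log"

        "github.com/ollama/ollama/api"
    )

    func main() {
        client, err := api.ClientFromEnvironment()
        if err != nil {
            log.Fatal(err)
        }

        // Illustrative schema constraining the reply to {"capital": <string>}.
        // Only the top-level keys handled by llama.JsonSchema ($defs, properties,
        // required, title, type) are converted to the grammar.
        format := json.RawMessage(`{
            "type": "object",
            "properties": {"capital": {"type": "string"}},
            "required": ["capital"]
        }`)

        stream := false
        req := &api.ChatRequest{
            Model:    "llama3.2", // assumed to be pulled locally
            Messages: []api.Message{{Role: "user", Content: "What is the capital of France? Respond in JSON."}},
            Format:   format, // or json.RawMessage(`"json"`) for the legacy behavior
            Stream:   &stream,
        }

        if err := client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
            fmt.Print(resp.Message.Content)
            return nil
        }); err != nil {
            log.Fatal(err)
        }
    }

Note that the prompt still mentions JSON explicitly: the server only warns (it does not fail)
when a JSON format is requested without "json" appearing in the prompt.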