Mirror of https://github.com/ollama/ollama.git (synced 2025-04-06 02:48:39 +02:00)
api: structured outputs - chat endpoint (#7900)

Adds structured outputs to chat endpoint

Co-authored-by: Michael Yang <mxyng@pm.me>
Co-authored-by: Hieu Nguyen <hieunguyen1053@outlook.com>

parent eb8366d658
commit 630e7dc6ff
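
With this change, ChatRequest.Format carries raw JSON instead of a plain string: the legacy value "json" still works, and a full JSON schema can now be supplied, which the server converts into a sampling grammar. A minimal client-side sketch of the new capability (not part of the commit; the model name and schema contents are illustrative):

package main

import (
	"context"
	"encoding/json"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Constrain the reply to an object with a single required "capital" field.
	// The schema itself is illustrative, not taken from the commit.
	format := json.RawMessage(`{
		"type": "object",
		"properties": {"capital": {"type": "string"}},
		"required": ["capital"]
	}`)

	req := &api.ChatRequest{
		Model:    "llama3.2", // assumed model name
		Messages: []api.Message{{Role: "user", Content: "What is the capital of France? Respond in JSON."}},
		Format:   format,
	}

	if err := client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		fmt.Print(resp.Message.Content)
		return nil
	}); err != nil {
		log.Fatal(err)
	}
}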
@@ -94,7 +94,7 @@ type ChatRequest struct {
 	Stream *bool `json:"stream,omitempty"`
 
 	// Format is the format to return the response in (e.g. "json").
-	Format string `json:"format"`
+	Format json.RawMessage `json:"format,omitempty"`
 
 	// KeepAlive controls how long the model will stay loaded into memory
 	// following the request.
@@ -8,6 +8,7 @@ import (
 	"crypto/ed25519"
 	"crypto/rand"
 	"crypto/sha256"
+	"encoding/json"
 	"encoding/pem"
 	"errors"
 	"fmt"
@@ -1038,7 +1039,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 	req := &api.ChatRequest{
 		Model:    opts.Model,
 		Messages: opts.Messages,
-		Format:   opts.Format,
+		Format:   json.RawMessage(opts.Format),
 		Options:  opts.Options,
 	}
 
@@ -85,9 +85,12 @@ COMPILER inline get_compiler() {
 import "C"
 
 import (
+	"bytes"
 	_ "embed"
+	"encoding/json"
 	"errors"
 	"fmt"
+	"log/slog"
 	"runtime"
 	"runtime/cgo"
 	"slices"
@@ -699,3 +702,33 @@ func (s *SamplingContext) Sample(llamaContext *Context, idx int) int {
 func (s *SamplingContext) Accept(id int, applyGrammar bool) {
 	C.gpt_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
 }
+
+type JsonSchema struct {
+	Defs       map[string]any `json:"$defs,omitempty"`
+	Properties map[string]any `json:"properties,omitempty"`
+	Required   []string       `json:"required,omitempty"`
+	Title      string         `json:"title,omitempty"`
+	Type       string         `json:"type,omitempty"`
+}
+
+func (js JsonSchema) AsGrammar() string {
+	var b bytes.Buffer
+	if err := json.NewEncoder(&b).Encode(js); err != nil {
+		return ""
+	}
+
+	cStr := C.CString(b.String())
+	defer C.free(unsafe.Pointer(cStr))
+
+	// Allocate buffer for grammar output with reasonable size
+	const maxLen = 32768 // 32KB
+	buf := make([]byte, maxLen)
+
+	// Call C function to convert schema to grammar
+	length := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
+	if length == 0 {
+		slog.Warn("unable to convert schema to grammar")
+	}
+
+	return string(buf[:length])
+}
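
For orientation, a hypothetical snippet (not from the commit) showing how the new helper is meant to be called from Go inside the llama package: the schema is marshaled, handed through cgo to schema_to_grammar, and the resulting GBNF grammar comes back as a string, with "" signaling failure.

// Sketch, assuming it runs inside the llama package where JsonSchema lives.
schema := JsonSchema{
	Type: "object",
	Properties: map[string]any{
		"name": map[string]any{"type": "string"},
	},
	Required: []string{"name"},
}

grammar := schema.AsGrammar() // GBNF grammar text, or "" on conversion failure
if grammar == "" {
	// caller should fall back or surface an error; AsGrammar itself only logs a warning
}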
@@ -1 +1,70 @@
 package llama
+
+import (
+	"strings"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+)
+
+func TestJsonSchema(t *testing.T) {
+	testCases := []struct {
+		name     string
+		schema   JsonSchema
+		expected string
+	}{
+		{
+			name: "empty schema",
+			schema: JsonSchema{
+				Type: "object",
+			},
+			expected: `array ::= "[" space ( value ("," space value)* )? "]" space
+boolean ::= ("true" | "false") space
+char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
+decimal-part ::= [0-9]{1,16}
+integral-part ::= [0] | [1-9] [0-9]{0,15}
+null ::= "null" space
+number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
+object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
+root ::= object
+space ::= | " " | "\n" [ \t]{0,20}
+string ::= "\"" char* "\"" space
+value ::= object | array | string | number | boolean | null`,
+		},
+		{
+			name: "invalid schema with circular reference",
+			schema: JsonSchema{
+				Type: "object",
+				Properties: map[string]any{
+					"self": map[string]any{
+						"$ref": "#", // Self reference
+					},
+				},
+			},
+			expected: "", // Should return empty string for invalid schema
+		},
+		{
+			name: "schema with invalid type",
+			schema: JsonSchema{
+				Type: "invalid_type", // Invalid type
+				Properties: map[string]any{
+					"foo": map[string]any{
+						"type": "string",
+					},
+				},
+			},
+			expected: "", // Should return empty string for invalid schema
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			result := tc.schema.AsGrammar()
+			if !strings.EqualFold(strings.TrimSpace(result), strings.TrimSpace(tc.expected)) {
+				if diff := cmp.Diff(tc.expected, result); diff != "" {
+					t.Fatalf("grammar mismatch (-want +got):\n%s", diff)
+				}
+			}
+		})
+	}
+}
llama/sampling_ext.cpp (vendored, 29 changes)
@@ -1,11 +1,13 @@
 // TODO: this is a temporary wrapper to allow calling C++ code from CGo
 #include "sampling.h"
 #include "sampling_ext.h"
+#include "json-schema-to-grammar.h"
 
 struct gpt_sampler *gpt_sampler_cinit(
     const struct llama_model *model, struct gpt_sampler_cparams *params)
 {
-    try {
+    try
+    {
         gpt_sampler_params sparams;
         sparams.top_k = params->top_k;
         sparams.top_p = params->top_p;
@@ -24,7 +26,9 @@ struct gpt_sampler *gpt_sampler_cinit(
         sparams.seed = params->seed;
         sparams.grammar = params->grammar;
         return gpt_sampler_init(model, sparams);
-    } catch (const std::exception & err) {
+    }
+    catch (const std::exception &err)
+    {
         return nullptr;
     }
 }
@@ -54,3 +58,24 @@ void gpt_sampler_caccept(
 {
     gpt_sampler_accept(sampler, id, apply_grammar);
 }
+
+int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len)
+{
+    try
+    {
+        nlohmann::json schema = nlohmann::json::parse(json_schema);
+        std::string grammar_str = json_schema_to_grammar(schema);
+        size_t len = grammar_str.length();
+        if (len >= max_len)
+        {
+            len = max_len - 1;
+        }
+        strncpy(grammar, grammar_str.c_str(), len);
+        return len;
+    }
+    catch (const std::exception &e)
+    {
+        strncpy(grammar, "", max_len - 1);
+        return 0;
+    }
+}
llama/sampling_ext.h (vendored, 2 changes)
@@ -47,6 +47,8 @@ extern "C"
         llama_token id,
         bool apply_grammar);
 
+    int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len);
+
 #ifdef __cplusplus
 }
 #endif
@@ -634,27 +634,22 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
 const jsonGrammar = `
 root   ::= object
 value  ::= object | array | string | number | ("true" | "false" | "null") ws
 
 object ::=
   "{" ws (
     string ":" ws value
     ("," ws string ":" ws value)*
   )? "}" ws
 
 array  ::=
   "[" ws (
     value
     ("," ws value)*
   )? "]" ws
 
 string ::=
   "\"" (
     [^"\\\x7F\x00-\x1F] |
     "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
   )* "\"" ws
 
 number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
 
 # Optional space: by convention, applied in this grammar after literal chars when allowed
 ws ::= ([ \t\n] ws)?
 `
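
As a quick illustration of what this pre-existing grammar admits: root ::= object pins valid completions to a JSON object, so an array or bare string cannot be a top-level result (example values illustrative, not from the commit):

// Illustrative only: strings judged against the jsonGrammar above.
accepted := `{"capital": "Paris"}` // root ::= object, so an object parses
rejected := `["Paris"]`            // an array is a value, but not a valid root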
@@ -684,7 +679,7 @@ type completion struct {
 
 type CompletionRequest struct {
 	Prompt  string
-	Format  string
+	Format  json.RawMessage
 	Images  []ImageData
 	Options *api.Options
 }
@@ -749,10 +744,22 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 		return fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
-	if req.Format == "json" {
-		request["grammar"] = jsonGrammar
-		if !strings.Contains(strings.ToLower(req.Prompt), "json") {
-			slog.Warn("Prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
-		}
-	}
+	// TODO (parthsareen): Move conversion to grammar with sampling logic
+	// API should do error handling for invalid formats
+	if req.Format != nil {
+		if strings.ToLower(strings.TrimSpace(string(req.Format))) == `"json"` {
+			request["grammar"] = jsonGrammar
+			if !strings.Contains(strings.ToLower(req.Prompt), "json") {
+				slog.Warn("prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
+			}
+		} else if schema, err := func() (llama.JsonSchema, error) {
+			var schema llama.JsonSchema
+			err := json.Unmarshal(req.Format, &schema)
+			return schema, err
+		}(); err == nil {
+			request["grammar"] = schema.AsGrammar()
+		} else {
+			slog.Warn(`format is neither a schema or "json"`, "format", req.Format)
+		}
+	}
 
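
To summarize the branch above: two format shapes are recognized, and anything else only logs a warning and applies no grammar. A sketch of the two accepted values from a caller's point of view (other CompletionRequest fields elided; values illustrative):

// Legacy: the JSON string "json" selects the built-in jsonGrammar.
legacy := llm.CompletionRequest{Format: json.RawMessage(`"json"`)}

// New: a schema object is unmarshaled into llama.JsonSchema and converted
// to a grammar with AsGrammar.
structured := llm.CompletionRequest{
	Format: json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}}}`),
}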
@@ -62,7 +62,12 @@ type Usage struct {
 }
 
 type ResponseFormat struct {
-	Type string `json:"type"`
+	Type       string      `json:"type"`
+	JsonSchema *JsonSchema `json:"json_schema,omitempty"`
 }
 
+type JsonSchema struct {
+	Schema map[string]any `json:"schema"`
+}
+
 type EmbedRequest struct {
@@ -482,9 +487,21 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 		options["top_p"] = 1.0
 	}
 
-	var format string
-	if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
-		format = "json"
+	var format json.RawMessage
+	if r.ResponseFormat != nil {
+		switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
+		// Support the old "json_object" type for OpenAI compatibility
+		case "json_object":
+			format = json.RawMessage(`"json"`)
+		case "json_schema":
+			if r.ResponseFormat.JsonSchema != nil {
+				schema, err := json.Marshal(r.ResponseFormat.JsonSchema.Schema)
+				if err != nil {
+					return nil, fmt.Errorf("failed to marshal json schema: %w", err)
+				}
+				format = schema
+			}
+		}
 	}
 
 	return &api.ChatRequest{
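
For context, a sketch of an OpenAI-style request body this mapping now accepts (model and schema values illustrative): type "json_object" becomes the literal format "json", while type "json_schema" forwards response_format.json_schema.schema as the raw format value.

// Illustrative request body for the OpenAI-compatible chat endpoint.
body := `{
  "model": "llama3.2",
  "messages": [{"role": "user", "content": "Name one fruit, as JSON."}],
  "response_format": {
    "type": "json_schema",
    "json_schema": {"schema": {"type": "object", "properties": {"fruit": {"type": "string"}}}}
  }
}`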
@@ -13,6 +13,7 @@ import (
 	"time"
 
 	"github.com/gin-gonic/gin"
+	"github.com/google/go-cmp/cmp"
 
 	"github.com/ollama/ollama/api"
 )
@@ -107,7 +108,7 @@ func TestChatMiddleware(t *testing.T) {
 					"presence_penalty": 5.0,
 					"top_p":            6.0,
 				},
-				Format: "json",
+				Format: json.RawMessage(`"json"`),
 				Stream: &True,
 			},
 		},
@@ -316,13 +317,13 @@ func TestChatMiddleware(t *testing.T) {
 			if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
 				t.Fatal(err)
 			}
 			return
 		}
-		if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
-			t.Fatal("requests did not match")
+		if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
+			t.Fatalf("requests did not match: %+v", diff)
 		}
 
-		if !reflect.DeepEqual(tc.err, errResp) {
-			t.Fatal("errors did not match")
+		if diff := cmp.Diff(tc.err, errResp); diff != "" {
+			t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
 		}
 		})
 	}
@@ -278,7 +278,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 		Prompt:  prompt,
 		Images:  images,
-		Format:  req.Format,
+		Format:  json.RawMessage(req.Format),
 		Options: opts,
 	}, func(cr llm.CompletionResponse) {
 		res := api.GenerateResponse{