api: structured outputs - chat endpoint (#7900)

Adds structured outputs to chat endpoint
---------

Co-authored-by: Michael Yang <mxyng@pm.me>
Co-authored-by: Hieu Nguyen <hieunguyen1053@outlook.com>
This commit is contained in:
Parth Sareen
2024-12-04 16:31:19 -08:00
committed by GitHub
parent eb8366d658
commit 630e7dc6ff
10 changed files with 180 additions and 25 deletions

View File

@@ -85,9 +85,12 @@ COMPILER inline get_compiler() {
import "C"
import (
"bytes"
_ "embed"
"encoding/json"
"errors"
"fmt"
"log/slog"
"runtime"
"runtime/cgo"
"slices"
@@ -699,3 +702,33 @@ func (s *SamplingContext) Sample(llamaContext *Context, idx int) int {
func (s *SamplingContext) Accept(id int, applyGrammar bool) {
C.gpt_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
}
type JsonSchema struct {
Defs map[string]any `json:"$defs,omitempty"`
Properties map[string]any `json:"properties,omitempty"`
Required []string `json:"required,omitempty"`
Title string `json:"title,omitempty"`
Type string `json:"type,omitempty"`
}
func (js JsonSchema) AsGrammar() string {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(js); err != nil {
return ""
}
cStr := C.CString(b.String())
defer C.free(unsafe.Pointer(cStr))
// Allocate buffer for grammar output with reasonable size
const maxLen = 32768 // 32KB
buf := make([]byte, maxLen)
// Call C function to convert schema to grammar
length := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
if length == 0 {
slog.Warn("unable to convert schema to grammar")
}
return string(buf[:length])
}

View File

@@ -1 +1,70 @@
package llama
import (
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestJsonSchema(t *testing.T) {
testCases := []struct {
name string
schema JsonSchema
expected string
}{
{
name: "empty schema",
schema: JsonSchema{
Type: "object",
},
expected: `array ::= "[" space ( value ("," space value)* )? "]" space
boolean ::= ("true" | "false") space
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
null ::= "null" space
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
root ::= object
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
value ::= object | array | string | number | boolean | null`,
},
{
name: "invalid schema with circular reference",
schema: JsonSchema{
Type: "object",
Properties: map[string]any{
"self": map[string]any{
"$ref": "#", // Self reference
},
},
},
expected: "", // Should return empty string for invalid schema
},
{
name: "schema with invalid type",
schema: JsonSchema{
Type: "invalid_type", // Invalid type
Properties: map[string]any{
"foo": map[string]any{
"type": "string",
},
},
},
expected: "", // Should return empty string for invalid schema
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := tc.schema.AsGrammar()
if !strings.EqualFold(strings.TrimSpace(result), strings.TrimSpace(tc.expected)) {
if diff := cmp.Diff(tc.expected, result); diff != "" {
t.Fatalf("grammar mismatch (-want +got):\n%s", diff)
}
}
})
}
}

View File

@@ -1,11 +1,13 @@
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
#include "sampling.h"
#include "sampling_ext.h"
#include "json-schema-to-grammar.h"
struct gpt_sampler *gpt_sampler_cinit(
const struct llama_model *model, struct gpt_sampler_cparams *params)
{
try {
try
{
gpt_sampler_params sparams;
sparams.top_k = params->top_k;
sparams.top_p = params->top_p;
@@ -24,7 +26,9 @@ struct gpt_sampler *gpt_sampler_cinit(
sparams.seed = params->seed;
sparams.grammar = params->grammar;
return gpt_sampler_init(model, sparams);
} catch (const std::exception & err) {
}
catch (const std::exception &err)
{
return nullptr;
}
}
@@ -54,3 +58,24 @@ void gpt_sampler_caccept(
{
gpt_sampler_accept(sampler, id, apply_grammar);
}
int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len)
{
try
{
nlohmann::json schema = nlohmann::json::parse(json_schema);
std::string grammar_str = json_schema_to_grammar(schema);
size_t len = grammar_str.length();
if (len >= max_len)
{
len = max_len - 1;
}
strncpy(grammar, grammar_str.c_str(), len);
return len;
}
catch (const std::exception &e)
{
strncpy(grammar, "", max_len - 1);
return 0;
}
}

View File

@@ -47,6 +47,8 @@ extern "C"
llama_token id,
bool apply_grammar);
int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len);
#ifdef __cplusplus
}
#endif