diff --git a/llm/server.go b/llm/server.go index 832863720..bb9062adc 100644 --- a/llm/server.go +++ b/llm/server.go @@ -700,20 +700,24 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu } if len(req.Format) > 0 { - switch { - case bytes.Equal(req.Format, []byte(`""`)): - // fallthrough - case bytes.Equal(req.Format, []byte(`"json"`)): + switch string(req.Format) { + case `null`, `""`: + // Field was set, but "missing" a value. We accept + // these as "not set". + break + case `"json"`: request["grammar"] = grammarJSON - case bytes.HasPrefix(req.Format, []byte("{")): + default: + if req.Format[0] != '{' { + return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format) + } + // User provided a JSON schema g := llama.SchemaToGrammar(req.Format) if g == nil { return fmt.Errorf("invalid JSON schema in format") } request["grammar"] = string(g) - default: - return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema", req.Format) } } diff --git a/llm/server_test.go b/llm/server_test.go index e6f79a585..6c8f7590b 100644 --- a/llm/server_test.go +++ b/llm/server_test.go @@ -39,25 +39,34 @@ func TestLLMServerCompletionFormat(t *testing.T) { cancel() // prevent further processing if request makes it past the format check - checkCanceled := func(err error) { + checkValid := func(err error) { t.Helper() if !errors.Is(err, context.Canceled) { t.Fatalf("Completion: err = %v; expected context.Canceled", err) } } - valids := []string{`"json"`, `{"type":"object"}`, ``, `""`} + valids := []string{ + // "missing" + ``, + `""`, + `null`, + + // JSON + `"json"`, + `{"type":"object"}`, + } for _, valid := range valids { err := s.Completion(ctx, CompletionRequest{ Options: new(api.Options), Format: []byte(valid), }, nil) - checkCanceled(err) + checkValid(err) } err := s.Completion(ctx, CompletionRequest{ Options: new(api.Options), Format: nil, // missing format }, nil) - checkCanceled(err) + checkValid(err) }