From 81888abbe42ec668242fd9e68a5f9391542469b6 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Thu, 6 Mar 2025 21:44:52 -0800 Subject: [PATCH] wip: apply gbnf vocab to logits --- llama/grammar.go | 135 ++++++++++++++++++ llama/grammar_ext.cpp | 83 +++++++++++ llama/grammar_ext.h | 33 +++++ llama/llama.go | 1 + .../0019-expose-llama_vocab-from-tokens.patch | 117 +++++++++++++++ runner/ollamarunner/runner.go | 10 +- sample/samplers.go | 15 +- 7 files changed, 392 insertions(+), 2 deletions(-) create mode 100644 llama/grammar.go create mode 100644 llama/grammar_ext.cpp create mode 100644 llama/grammar_ext.h create mode 100644 llama/patches/0019-expose-llama_vocab-from-tokens.patch diff --git a/llama/grammar.go b/llama/grammar.go new file mode 100644 index 000000000..2ffc7bc82 --- /dev/null +++ b/llama/grammar.go @@ -0,0 +1,135 @@ +package llama + +/* +#cgo CFLAGS: -std=c11 +#cgo CXXFLAGS: -std=c++17 +#cgo CPPFLAGS: -I${SRCDIR}/../llama/llama.cpp/include +#cgo CPPFLAGS: -I${SRCDIR}/../llama/llama.cpp/common +#cgo CPPFLAGS: -I${SRCDIR}/../llama/llama.cpp/src +#cgo CPPFLAGS: -I${SRCDIR} + +#include +#include +#include "llama.h" +#include "grammar_ext.h" + +// Helper function to handle Go string arrays to C +static char** makeCharArray(int size) { + return (char**)malloc(size * sizeof(char*)); +} + +static void setArrayString(char** a, int i, const char* s) { + a[i] = (char*)s; +} + +static void freeCharArray(char** a, int size) { + free(a); +} +*/ +import "C" + +import ( + "errors" + "runtime" + "unsafe" +) + +// Grammar represents the interface for grammar-based sampling +type Grammar interface { + Apply(logits []float32) ([]float32, error) + Close() error +} + +// CGrammar is a wrapper around the C++ grammar implementation +type CGrammar struct { + grammar *C.struct_llama_grammar + model *C.struct_llama_model + closed bool +} + +// NewGrammarWithTokens creates a new grammar using a custom vocabulary defined by tokens +func NewGrammarWithTokens(grammarStr, grammarRoot string, tokens []string) (Grammar, error) { + if grammarStr == "" { + return nil, errors.New("empty grammar string") + } + + if len(tokens) == 0 { + return nil, errors.New("empty token list") + } + + // Create C array of strings for tokens + cTokens := C.makeCharArray(C.int(len(tokens))) + defer C.freeCharArray(cTokens, C.int(len(tokens))) + + // Convert Go strings to C strings and set them in the array + cStrings := make([]*C.char, len(tokens)) + for i, token := range tokens { + cStrings[i] = C.CString(token) + C.setArrayString(cTokens, C.int(i), cStrings[i]) + } + + // Create vocabulary from tokens + cVocab := C.vocab_bridge_from_tokens((**C.char)(unsafe.Pointer(cTokens)), C.int(len(tokens))) + + // Free the C strings after creating the vocab + for _, str := range cStrings { + C.free(unsafe.Pointer(str)) + } + + if cVocab == nil { + return nil, errors.New("failed to create vocabulary from tokens") + } + + // Make sure to free the vocabulary when we're done + defer C.vocab_bridge_free(cVocab) + + cGrammarStr := C.CString(grammarStr) + defer C.free(unsafe.Pointer(cGrammarStr)) + + cGrammarRoot := C.CString(grammarRoot) + defer C.free(unsafe.Pointer(cGrammarRoot)) + + // Create grammar using our C wrapper function with the correct signature + grammar := C.grammar_create_from_string(cVocab, cGrammarStr, cGrammarRoot) + if grammar == nil { + return nil, errors.New("failed to initialize grammar") + } + + cg := &CGrammar{ + grammar: grammar, + closed: false, + } + + // Set up finalizer to free resources when the object is garbage collected + runtime.SetFinalizer(cg, func(g *CGrammar) { + g.Close() + }) + + return cg, nil +} + +// Apply applies grammar constraints to logits +func (g *CGrammar) Apply(logits []float32) ([]float32, error) { + if g.closed || g.grammar == nil { + return nil, errors.New("grammar not initialized or already closed") + } + + // Create a copy of logits to modify + result := make([]float32, len(logits)) + copy(result, logits) + + // Apply grammar constraints using our C wrapper function + C.grammar_apply_to_logits(g.grammar, (*C.float)(&result[0]), C.int(len(result))) + + return result, nil +} + +// Close releases resources associated with the grammar +func (g *CGrammar) Close() error { + if !g.closed && g.grammar != nil { + C.grammar_free(g.grammar) + g.grammar = nil + g.closed = true + } + return nil +} diff --git a/llama/grammar_ext.cpp b/llama/grammar_ext.cpp new file mode 100644 index 000000000..f1ad3a913 --- /dev/null +++ b/llama/grammar_ext.cpp @@ -0,0 +1,83 @@ +#include +#include +#include +#include +#include + +#include "llama-sampling.h" +#include "llama-grammar.h" +#include "llama-vocab.h" +#include "grammar_ext.h" + +extern "C" { + +struct llama_grammar* grammar_create_from_string(const struct llama_vocab* vocab, const char* grammar_str, const char* grammar_root) { + try { + // Initialize grammar sampler directly with the model + struct llama_sampler* sampler = llama_sampler_init_grammar(vocab, grammar_str, grammar_root); + if (!sampler) { + return nullptr; + } + + // Cast the sampler to a grammar and return it + return (struct llama_grammar*)sampler; + } catch (const std::exception &err) { + return nullptr; + } +} + +void grammar_apply_to_logits(struct llama_grammar* grammar, float* logits, int n_logits) { + if (!grammar || !logits || n_logits <= 0) { + return; + } + + // Create token data array for the grammar application + llama_token_data* token_data = (llama_token_data*)malloc(n_logits * sizeof(llama_token_data)); + if (!token_data) { + return; + } + + // Initialize token data from logits + for (int i = 0; i < n_logits; i++) { + token_data[i].id = i; + token_data[i].logit = logits[i]; + token_data[i].p = 0.0f; + } + + // Create token data array structure + llama_token_data_array arr = { + .data = token_data, + .size = (size_t)n_logits, + .sorted = false, + .selected = -1 + }; + + // Apply grammar constraints to the token data array + llama_grammar_apply_impl(*grammar, &arr); + + // Copy back the modified logits + for (int i = 0; i < n_logits; i++) { + logits[i] = token_data[i].logit; + } + + free(token_data); +} + +void grammar_free(struct llama_grammar* grammar) { + if (grammar) { + // Free the grammar as a sampler + llama_sampler_free((struct llama_sampler*)grammar); + } +} + +struct llama_vocab* vocab_bridge_from_tokens(const char** tokens, int n_tokens) { + // Call the C++ function from llama-vocab.cpp + return llama_vocab_from_tokens(tokens, n_tokens); +} + +void vocab_bridge_free(struct llama_vocab* vocab) { + // Call the C++ function from llama-vocab.cpp + llama_vocab_free(vocab); +} + +} // extern "C" \ No newline at end of file diff --git a/llama/grammar_ext.h b/llama/grammar_ext.h new file mode 100644 index 000000000..468e7dae7 --- /dev/null +++ b/llama/grammar_ext.h @@ -0,0 +1,33 @@ +#ifndef GRAMMAR_EXT_H +#define GRAMMAR_EXT_H + +#include "llama.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// Forward declarations +struct llama_grammar; +struct llama_vocab; + +// Create a new grammar from a string (returns a grammar implemented as a sampler) +struct llama_grammar* grammar_create_from_string(const struct llama_vocab* vocab, const char* grammar_str, const char* grammar_root); + +// Apply grammar constraints to logits +void grammar_apply_to_logits(struct llama_grammar* grammar, float* logits, int n_logits); + +// Free grammar resources (frees the underlying sampler) +void grammar_free(struct llama_grammar* grammar); + +// C wrapper for llama_vocab_from_tokens +struct llama_vocab* vocab_bridge_from_tokens(const char** tokens, int n_tokens); + +// C wrapper for llama_vocab_free +void vocab_bridge_free(struct llama_vocab* vocab); + +#ifdef __cplusplus +} +#endif + +#endif // GRAMMAR_EXT_H \ No newline at end of file diff --git a/llama/llama.go b/llama/llama.go index bb5028bd9..a6ff05bd6 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -18,6 +18,7 @@ package llama #include "mllama.h" #include "sampling_ext.h" +#include "grammar_ext.h" extern bool llamaProgressCallback(float progress, void *user_data); extern void llamaLog(int level, char* text, void* user_data); diff --git a/llama/patches/0019-expose-llama_vocab-from-tokens.patch b/llama/patches/0019-expose-llama_vocab-from-tokens.patch new file mode 100644 index 000000000..e71c4f631 --- /dev/null +++ b/llama/patches/0019-expose-llama_vocab-from-tokens.patch @@ -0,0 +1,117 @@ +From 668a974433edccf2c5fcc2192c39aed601e575f2 Mon Sep 17 00:00:00 2001 +From: Bruce MacDonald +Date: Thu, 6 Mar 2025 21:07:06 -0800 +Subject: [PATCH] expose llama_vocab from tokens + +--- + llama/llama.cpp/src/llama-vocab.cpp | 73 +++++++++++++++++++++++++++++ + llama/llama.cpp/src/llama-vocab.h | 11 ++++- + 2 files changed, 83 insertions(+), 1 deletion(-) + +diff --git a/llama/llama.cpp/src/llama-vocab.cpp b/llama/llama.cpp/src/llama-vocab.cpp +index c7ff28be..ad6e7ad8 100644 +--- a/llama/llama.cpp/src/llama-vocab.cpp ++++ b/llama/llama.cpp/src/llama-vocab.cpp +@@ -3253,3 +3253,76 @@ int32_t llama_detokenize( + return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special); + } + ++struct llama_vocab *llama_vocab_from_tokens(const char **tokens, int n_tokens) ++{ ++ if (!tokens || n_tokens <= 0) ++ { ++ return nullptr; ++ } ++ ++ try ++ { ++ // Create a new vocabulary instance ++ llama_vocab *vocab = new llama_vocab(); ++ vocab->pimpl = std::make_unique(*vocab); ++ ++ // Resize the token data vectors ++ vocab->pimpl->id_to_token.resize(n_tokens); ++ ++ // Create mappings for all tokens ++ for (int i = 0; i < n_tokens; i++) ++ { ++ std::string word = tokens[i]; ++ if (word.empty()) ++ { ++ word = "[EMPTY_" + std::to_string(i) + "]"; ++ } ++ ++ // Add to token mappings ++ vocab->pimpl->token_to_id[word] = i; ++ ++ // Set up token data ++ auto &token_data = vocab->pimpl->id_to_token[i]; ++ token_data.text = std::move(word); ++ token_data.score = 0.0f; // Default score ++ token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; ++ ++ // Detect special tokens ++ if (word == "" || word == "") ++ { ++ vocab->pimpl->special_bos_id = i; ++ } ++ else if (word == "" || word == "" || word == "<|endoftext|>") ++ { ++ vocab->pimpl->special_eos_id = i; ++ vocab->pimpl->special_eog_ids.insert(i); ++ } ++ else if (word == "") ++ { ++ vocab->pimpl->special_unk_id = i; ++ } ++ } ++ ++ // Initialize the token-to-piece cache ++ vocab->pimpl->cache_token_to_piece.resize(n_tokens); ++ for (int i = 0; i < n_tokens; i++) ++ { ++ vocab->pimpl->cache_token_to_piece[i] = vocab->pimpl->id_to_token[i].text; ++ } ++ ++ return vocab; ++ } ++ catch (const std::exception &err) ++ { ++ return nullptr; ++ } ++} ++ ++// Helper function to free the vocab ++void llama_vocab_free(struct llama_vocab *vocab) ++{ ++ if (vocab) ++ { ++ delete vocab; ++ } ++} +\ No newline at end of file +diff --git a/llama/llama.cpp/src/llama-vocab.h b/llama/llama.cpp/src/llama-vocab.h +index 5ce35521..eceb28f3 100644 +--- a/llama/llama.cpp/src/llama-vocab.h ++++ b/llama/llama.cpp/src/llama-vocab.h +@@ -119,7 +119,16 @@ struct llama_vocab { + + void print_info() const; + +-private: + struct impl; + std::unique_ptr pimpl; + }; ++ ++// Create a vocabulary from an array of token strings ++// tokens: Array of token strings ++// n_tokens: Number of tokens in the array ++// Returns: A new llama_vocab instance, or nullptr on failure ++// The caller is responsible for freeing the vocabulary using llama_vocab_free ++LLAMA_API struct llama_vocab * llama_vocab_from_tokens(const char ** tokens, int n_tokens); ++ ++// Free a vocabulary created with llama_vocab_from_tokens ++LLAMA_API void llama_vocab_free(struct llama_vocab * vocab); +-- +2.39.3 (Apple Git-145) + diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 1a4bbf19e..c2c396c0b 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -428,7 +428,8 @@ func (s *Server) processBatch() error { // sample a token vocabSize := len(logits) / len(options.Outputs) - + // TODO: need access to vocab to apply grammar + // token = sampler.Grammar.Apply(logits) token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize]) if err != nil { return fmt.Errorf("failed to sample token: %w", err) @@ -575,6 +576,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } + // TODO: if grammar is provided, load it + // if req.Grammar != "" { + // grammar := llama.NewGrammarWithTokens(req.Grammar, "root", s.model.Vocabulary) + // } + // defer grammar.Close() + // sampler := sample.WithGrammar(sample.Greedy(), grammar) + seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ numPredict: req.NumPredict, stop: req.Stop, diff --git a/sample/samplers.go b/sample/samplers.go index 1b8a5edd9..bb09e6955 100644 --- a/sample/samplers.go +++ b/sample/samplers.go @@ -4,6 +4,7 @@ import ( "errors" "math" + "github.com/ollama/ollama/llama" "golang.org/x/exp/rand" "gonum.org/v1/gonum/stat/sampleuv" ) @@ -57,12 +58,24 @@ func (s weighted) Sample(logits []float32) (int32, error) { return -1, errors.New("weighted sampler failed, no valid token found") } -type greedy struct{} +type greedy struct { + grammar llama.Grammar +} func Greedy() Sampler { return greedy{} } +func WithGrammar(s Sampler, grammar llama.Grammar) Sampler { + switch t := s.(type) { + case greedy: + t.grammar = grammar + return t + default: + return s + } +} + // Sample returns the index of the maximum value in logits. func (s greedy) Sample(logits []float32) (int32, error) { if len(logits) == 0 {