wip: apply gbnf vocab to logits

This commit is contained in:
Bruce MacDonald 2025-03-06 21:44:52 -08:00
parent 05a01fdecb
commit 81888abbe4
7 changed files with 392 additions and 2 deletions

135
llama/grammar.go Normal file
View File

@ -0,0 +1,135 @@
package llama
/*
#cgo CFLAGS: -std=c11
#cgo CXXFLAGS: -std=c++17
#cgo CPPFLAGS: -I${SRCDIR}/../llama/llama.cpp/include
#cgo CPPFLAGS: -I${SRCDIR}/../llama/llama.cpp/common
#cgo CPPFLAGS: -I${SRCDIR}/../llama/llama.cpp/src
#cgo CPPFLAGS: -I${SRCDIR}
#include <stdlib.h>
#include <stdbool.h>
#include "llama.h"
#include "grammar_ext.h"
// Helper function to handle Go string arrays to C
static char** makeCharArray(int size) {
return (char**)malloc(size * sizeof(char*));
}
static void setArrayString(char** a, int i, const char* s) {
a[i] = (char*)s;
}
static void freeCharArray(char** a, int size) {
free(a);
}
*/
import "C"
import (
"errors"
"runtime"
"unsafe"
)
// Grammar represents the interface for grammar-based sampling.
// Implementations constrain token logits so that only tokens permitted by
// the grammar remain selectable.
type Grammar interface {
	// Apply returns a new logits slice with grammar constraints applied;
	// the input slice is not modified.
	Apply(logits []float32) ([]float32, error)
	// Close releases any underlying (native) resources. Implementations
	// should make it safe to call more than once.
	Close() error
}
// CGrammar is a wrapper around the C++ (llama.cpp) grammar implementation.
// The grammar handle returned by grammar_create_from_string retains a pointer
// to the vocabulary it was built from, so the vocab must stay alive for the
// grammar's whole lifetime; both are released together in Close.
type CGrammar struct {
	grammar *C.struct_llama_grammar
	vocab   *C.struct_llama_vocab // owned; freed in Close, after the grammar
	closed  bool
}

// NewGrammarWithTokens creates a new grammar using a custom vocabulary
// defined by tokens. grammarStr is the GBNF grammar text and grammarRoot is
// the name of the root rule. It returns an error if the grammar string or
// token list is empty, or if the native vocabulary/grammar cannot be built.
func NewGrammarWithTokens(grammarStr, grammarRoot string, tokens []string) (Grammar, error) {
	if grammarStr == "" {
		return nil, errors.New("empty grammar string")
	}
	if len(tokens) == 0 {
		return nil, errors.New("empty token list")
	}

	// Build a C array of C strings holding the vocabulary tokens.
	cTokens := C.makeCharArray(C.int(len(tokens)))
	defer C.freeCharArray(cTokens, C.int(len(tokens)))

	cStrings := make([]*C.char, len(tokens))
	for i, token := range tokens {
		cStrings[i] = C.CString(token)
		C.setArrayString(cTokens, C.int(i), cStrings[i])
	}
	// The native vocabulary copies the token text (std::string), so the C
	// strings can be released once this function returns.
	defer func() {
		for _, s := range cStrings {
			C.free(unsafe.Pointer(s))
		}
	}()

	cVocab := C.vocab_bridge_from_tokens((**C.char)(unsafe.Pointer(cTokens)), C.int(len(tokens)))
	if cVocab == nil {
		return nil, errors.New("failed to create vocabulary from tokens")
	}
	// NOTE: the vocab is intentionally NOT freed here. The grammar below keeps
	// a reference to it, so freeing it now would leave the grammar dangling;
	// ownership passes to the CGrammar and it is freed in Close.

	cGrammarStr := C.CString(grammarStr)
	defer C.free(unsafe.Pointer(cGrammarStr))
	cGrammarRoot := C.CString(grammarRoot)
	defer C.free(unsafe.Pointer(cGrammarRoot))

	grammar := C.grammar_create_from_string(cVocab, cGrammarStr, cGrammarRoot)
	if grammar == nil {
		C.vocab_bridge_free(cVocab)
		return nil, errors.New("failed to initialize grammar")
	}

	cg := &CGrammar{
		grammar: grammar,
		vocab:   cVocab,
	}
	// Free native resources if the caller forgets to Close.
	runtime.SetFinalizer(cg, func(g *CGrammar) {
		g.Close()
	})
	return cg, nil
}

// Apply applies grammar constraints to logits and returns the constrained
// copy; the input slice is left untouched. An empty input yields an empty
// result without calling into C (avoids taking the address of element 0 of
// an empty slice).
func (g *CGrammar) Apply(logits []float32) ([]float32, error) {
	if g.closed || g.grammar == nil {
		return nil, errors.New("grammar not initialized or already closed")
	}
	result := make([]float32, len(logits))
	if len(result) == 0 {
		return result, nil
	}
	copy(result, logits)
	C.grammar_apply_to_logits(g.grammar, (*C.float)(&result[0]), C.int(len(result)))
	// Keep g (and therefore its finalizer target) alive across the cgo call
	// so the grammar cannot be freed while C code is still using it.
	runtime.KeepAlive(g)
	return result, nil
}

// Close releases the native grammar and its vocabulary. It is idempotent and
// always returns nil.
func (g *CGrammar) Close() error {
	if g.closed {
		return nil
	}
	if g.grammar != nil {
		C.grammar_free(g.grammar)
		g.grammar = nil
	}
	// Free the vocab only after the grammar that references it is gone.
	if g.vocab != nil {
		C.vocab_bridge_free(g.vocab)
		g.vocab = nil
	}
	g.closed = true
	// Resources are released; the finalizer is no longer needed.
	runtime.SetFinalizer(g, nil)
	return nil
}

83
llama/grammar_ext.cpp vendored Normal file
View File

@ -0,0 +1,83 @@
#include <stdlib.h>
#include <string>
#include <vector>
#include <cstdint>
#include <stdexcept>
#include "llama-sampling.h"
#include "llama-grammar.h"
#include "llama-vocab.h"
#include "grammar_ext.h"
extern "C" {
// Create a grammar-constrained sampler from a textual GBNF definition.
// The returned handle is really a llama_sampler cast to llama_grammar*;
// it must be released with grammar_free. Returns nullptr on any failure.
struct llama_grammar* grammar_create_from_string(const struct llama_vocab* vocab, const char* grammar_str, const char* grammar_root) {
    try {
        struct llama_sampler* sampler = llama_sampler_init_grammar(vocab, grammar_str, grammar_root);
        if (sampler == nullptr) {
            return nullptr;
        }
        // Hand the sampler back as an opaque grammar handle.
        return reinterpret_cast<struct llama_grammar*>(sampler);
    } catch (const std::exception&) {
        // Swallow parse/initialization errors; callers only see nullptr.
        return nullptr;
    }
}
void grammar_apply_to_logits(struct llama_grammar* grammar, float* logits, int n_logits) {
if (!grammar || !logits || n_logits <= 0) {
return;
}
// Create token data array for the grammar application
llama_token_data* token_data = (llama_token_data*)malloc(n_logits * sizeof(llama_token_data));
if (!token_data) {
return;
}
// Initialize token data from logits
for (int i = 0; i < n_logits; i++) {
token_data[i].id = i;
token_data[i].logit = logits[i];
token_data[i].p = 0.0f;
}
// Create token data array structure
llama_token_data_array arr = {
.data = token_data,
.size = (size_t)n_logits,
.sorted = false,
.selected = -1
};
// Apply grammar constraints to the token data array
llama_grammar_apply_impl(*grammar, &arr);
// Copy back the modified logits
for (int i = 0; i < n_logits; i++) {
logits[i] = token_data[i].logit;
}
free(token_data);
}
// Release a grammar handle created by grammar_create_from_string.
// Safe to call with nullptr.
void grammar_free(struct llama_grammar* grammar) {
    if (grammar == nullptr) {
        return;
    }
    // The handle is an opaque llama_sampler; release it through the sampler API.
    llama_sampler_free((struct llama_sampler*)grammar);
}
// Thin C bridge to llama_vocab_from_tokens (defined in llama-vocab.cpp):
// builds a llama_vocab from an array of n_tokens token strings.
// Returns nullptr on failure; the caller owns the result and must free it
// with vocab_bridge_free.
struct llama_vocab* vocab_bridge_from_tokens(const char** tokens, int n_tokens) {
    // Call the C++ function from llama-vocab.cpp
    return llama_vocab_from_tokens(tokens, n_tokens);
}
// Thin C bridge to llama_vocab_free (defined in llama-vocab.cpp): releases a
// vocabulary created with vocab_bridge_from_tokens. Safe on nullptr
// (the underlying free checks for null).
void vocab_bridge_free(struct llama_vocab* vocab) {
    // Call the C++ function from llama-vocab.cpp
    llama_vocab_free(vocab);
}
} // extern "C"

33
llama/grammar_ext.h vendored Normal file
View File

@ -0,0 +1,33 @@
#ifndef GRAMMAR_EXT_H
#define GRAMMAR_EXT_H
#include "llama.h"
#ifdef __cplusplus
extern "C" {
#endif
// Forward declarations (opaque to C callers; defined in llama.cpp sources)
struct llama_grammar;
struct llama_vocab;
// Create a new grammar from a GBNF string. grammar_root names the root rule.
// NOTE: the returned handle is a llama_sampler cast to llama_grammar*; treat
// it as opaque and release it only via grammar_free. Returns NULL on failure.
struct llama_grammar* grammar_create_from_string(const struct llama_vocab* vocab, const char* grammar_str, const char* grammar_root);
// Apply grammar constraints to a buffer of n_logits logits, in place.
// No-op if grammar or logits is NULL, or n_logits <= 0.
void grammar_apply_to_logits(struct llama_grammar* grammar, float* logits, int n_logits);
// Free grammar resources (frees the underlying sampler). Safe on NULL.
void grammar_free(struct llama_grammar* grammar);
// C wrapper for llama_vocab_from_tokens; caller owns the result and must
// free it with vocab_bridge_free. Returns NULL on failure.
struct llama_vocab* vocab_bridge_from_tokens(const char** tokens, int n_tokens);
// C wrapper for llama_vocab_free. Safe on NULL.
void vocab_bridge_free(struct llama_vocab* vocab);
#ifdef __cplusplus
}
#endif
#endif // GRAMMAR_EXT_H

View File

@ -18,6 +18,7 @@ package llama
#include "mllama.h"
#include "sampling_ext.h"
#include "grammar_ext.h"
extern bool llamaProgressCallback(float progress, void *user_data);
extern void llamaLog(int level, char* text, void* user_data);

View File

@ -0,0 +1,117 @@
From 668a974433edccf2c5fcc2192c39aed601e575f2 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald <brucewmacdonald@gmail.com>
Date: Thu, 6 Mar 2025 21:07:06 -0800
Subject: [PATCH] expose llama_vocab from tokens
---
llama/llama.cpp/src/llama-vocab.cpp | 73 +++++++++++++++++++++++++++++
llama/llama.cpp/src/llama-vocab.h | 11 ++++-
2 files changed, 83 insertions(+), 1 deletion(-)
diff --git a/llama/llama.cpp/src/llama-vocab.cpp b/llama/llama.cpp/src/llama-vocab.cpp
index c7ff28be..ad6e7ad8 100644
--- a/llama/llama.cpp/src/llama-vocab.cpp
+++ b/llama/llama.cpp/src/llama-vocab.cpp
@@ -3253,3 +3253,76 @@ int32_t llama_detokenize(
return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
}
+struct llama_vocab *llama_vocab_from_tokens(const char **tokens, int n_tokens)
+{
+ if (!tokens || n_tokens <= 0)
+ {
+ return nullptr;
+ }
+
+ try
+ {
+ // Create a new vocabulary instance
+ llama_vocab *vocab = new llama_vocab();
+ vocab->pimpl = std::make_unique<llama_vocab::impl>(*vocab);
+
+ // Resize the token data vectors
+ vocab->pimpl->id_to_token.resize(n_tokens);
+
+ // Create mappings for all tokens
+ for (int i = 0; i < n_tokens; i++)
+ {
+ std::string word = tokens[i];
+ if (word.empty())
+ {
+ word = "[EMPTY_" + std::to_string(i) + "]";
+ }
+
+ // Add to token mappings
+ vocab->pimpl->token_to_id[word] = i;
+
+ // Set up token data
+ auto &token_data = vocab->pimpl->id_to_token[i];
+ token_data.text = word; // copy, not move: word is still compared against special tokens below
+ token_data.score = 0.0f; // Default score
+ token_data.attr = LLAMA_TOKEN_ATTR_NORMAL;
+
+ // Detect special tokens
+ if (word == "<s>" || word == "<bos>")
+ {
+ vocab->pimpl->special_bos_id = i;
+ }
+ else if (word == "</s>" || word == "<eos>" || word == "<|endoftext|>")
+ {
+ vocab->pimpl->special_eos_id = i;
+ vocab->pimpl->special_eog_ids.insert(i);
+ }
+ else if (word == "<unk>")
+ {
+ vocab->pimpl->special_unk_id = i;
+ }
+ }
+
+ // Initialize the token-to-piece cache
+ vocab->pimpl->cache_token_to_piece.resize(n_tokens);
+ for (int i = 0; i < n_tokens; i++)
+ {
+ vocab->pimpl->cache_token_to_piece[i] = vocab->pimpl->id_to_token[i].text;
+ }
+
+ return vocab;
+ }
+ catch (const std::exception &err)
+ {
+ return nullptr;
+ }
+}
+
+// Helper function to free the vocab
+void llama_vocab_free(struct llama_vocab *vocab)
+{
+ if (vocab)
+ {
+ delete vocab;
+ }
+}
\ No newline at end of file
diff --git a/llama/llama.cpp/src/llama-vocab.h b/llama/llama.cpp/src/llama-vocab.h
index 5ce35521..eceb28f3 100644
--- a/llama/llama.cpp/src/llama-vocab.h
+++ b/llama/llama.cpp/src/llama-vocab.h
@@ -119,7 +119,16 @@ struct llama_vocab {
void print_info() const;
-private:
struct impl;
std::unique_ptr<impl> pimpl;
};
+
+// Create a vocabulary from an array of token strings
+// tokens: Array of token strings
+// n_tokens: Number of tokens in the array
+// Returns: A new llama_vocab instance, or nullptr on failure
+// The caller is responsible for freeing the vocabulary using llama_vocab_free
+LLAMA_API struct llama_vocab * llama_vocab_from_tokens(const char ** tokens, int n_tokens);
+
+// Free a vocabulary created with llama_vocab_from_tokens
+LLAMA_API void llama_vocab_free(struct llama_vocab * vocab);
--
2.39.3 (Apple Git-145)

View File

@ -428,7 +428,8 @@ func (s *Server) processBatch() error {
// sample a token
vocabSize := len(logits) / len(options.Outputs)
// TODO: need access to vocab to apply grammar
// token = sampler.Grammar.Apply(logits)
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
if err != nil {
return fmt.Errorf("failed to sample token: %w", err)
@ -575,6 +576,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
// TODO: if grammar is provided, load it
// if req.Grammar != "" {
// grammar := llama.NewGrammarWithTokens(req.Grammar, "root", s.model.Vocabulary)
// }
// defer grammar.Close()
// sampler := sample.WithGrammar(sample.Greedy(), grammar)
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.NumPredict,
stop: req.Stop,

View File

@ -4,6 +4,7 @@ import (
"errors"
"math"
"github.com/ollama/ollama/llama"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/stat/sampleuv"
)
@ -57,12 +58,24 @@ func (s weighted) Sample(logits []float32) (int32, error) {
return -1, errors.New("weighted sampler failed, no valid token found")
}
type greedy struct{}
type greedy struct {
grammar llama.Grammar
}
func Greedy() Sampler {
return greedy{}
}
func WithGrammar(s Sampler, grammar llama.Grammar) Sampler {
switch t := s.(type) {
case greedy:
t.grammar = grammar
return t
default:
return s
}
}
// Sample returns the index of the maximum value in logits.
func (s greedy) Sample(logits []float32) (int32, error) {
if len(logits) == 0 {