mirror of
https://github.com/ollama/ollama.git
synced 2025-04-02 09:00:28 +02:00
wip: apply gbnf vocab to logits
This commit is contained in:
parent
05a01fdecb
commit
81888abbe4
135
llama/grammar.go
Normal file
135
llama/grammar.go
Normal file
@ -0,0 +1,135 @@
|
||||
package llama
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -std=c11
|
||||
#cgo CXXFLAGS: -std=c++17
|
||||
#cgo CPPFLAGS: -I${SRCDIR}/../llama/llama.cpp/include
|
||||
#cgo CPPFLAGS: -I${SRCDIR}/../llama/llama.cpp/common
|
||||
#cgo CPPFLAGS: -I${SRCDIR}/../llama/llama.cpp/src
|
||||
#cgo CPPFLAGS: -I${SRCDIR}
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <stdbool.h>
|
||||
#include "llama.h"
|
||||
#include "grammar_ext.h"
|
||||
|
||||
// Helper function to handle Go string arrays to C
|
||||
static char** makeCharArray(int size) {
|
||||
return (char**)malloc(size * sizeof(char*));
|
||||
}
|
||||
|
||||
static void setArrayString(char** a, int i, const char* s) {
|
||||
a[i] = (char*)s;
|
||||
}
|
||||
|
||||
static void freeCharArray(char** a, int size) {
|
||||
free(a);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"runtime"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Grammar represents the interface for grammar-based sampling.
// Implementations constrain token logits so that sampling can only
// produce sequences accepted by the grammar.
type Grammar interface {
	// Apply returns a copy of logits with grammar-disallowed tokens
	// suppressed; the input slice is not modified.
	Apply(logits []float32) ([]float32, error)
	// Close releases any native resources held by the grammar.
	Close() error
}
|
||||
|
||||
// CGrammar is a wrapper around the C++ grammar implementation.
// Instances are created by NewGrammarWithTokens and must be released
// with Close (a finalizer is installed as a safety net).
type CGrammar struct {
	grammar *C.struct_llama_grammar // native grammar handle; freed by Close
	model   *C.struct_llama_model   // NOTE(review): never assigned or read anywhere in this file — confirm it is needed
	closed  bool                    // guards against double-free in Close
}
|
||||
|
||||
// NewGrammarWithTokens creates a new grammar using a custom vocabulary defined by tokens.
// grammarStr is a GBNF grammar definition, grammarRoot names its root rule, and
// tokens supplies the vocabulary (a token's index in the slice is its token id).
// The returned Grammar owns native resources; callers should Close it when done.
// Returns an error for an empty grammar string or empty token list, or if the
// native vocab/grammar construction fails.
func NewGrammarWithTokens(grammarStr, grammarRoot string, tokens []string) (Grammar, error) {
	if grammarStr == "" {
		return nil, errors.New("empty grammar string")
	}

	if len(tokens) == 0 {
		return nil, errors.New("empty token list")
	}

	// Create C array of strings for tokens
	cTokens := C.makeCharArray(C.int(len(tokens)))
	defer C.freeCharArray(cTokens, C.int(len(tokens)))

	// Convert Go strings to C strings and set them in the array
	cStrings := make([]*C.char, len(tokens))
	for i, token := range tokens {
		cStrings[i] = C.CString(token)
		C.setArrayString(cTokens, C.int(i), cStrings[i])
	}

	// Create vocabulary from tokens
	cVocab := C.vocab_bridge_from_tokens((**C.char)(unsafe.Pointer(cTokens)), C.int(len(tokens)))

	// Free the C strings after creating the vocab; the vocab copies the
	// bytes into std::strings (see llama_vocab_from_tokens), so they are
	// no longer needed.
	for _, str := range cStrings {
		C.free(unsafe.Pointer(str))
	}

	if cVocab == nil {
		return nil, errors.New("failed to create vocabulary from tokens")
	}

	// Make sure to free the vocabulary when we're done.
	// NOTE(review): this frees the vocab when the function returns, while the
	// grammar created below was initialized from it — confirm that
	// grammar_create_from_string copies everything it needs out of the vocab,
	// otherwise the grammar holds a dangling reference.
	defer C.vocab_bridge_free(cVocab)

	cGrammarStr := C.CString(grammarStr)
	defer C.free(unsafe.Pointer(cGrammarStr))

	cGrammarRoot := C.CString(grammarRoot)
	defer C.free(unsafe.Pointer(cGrammarRoot))

	// Create grammar using our C wrapper function with the correct signature
	grammar := C.grammar_create_from_string(cVocab, cGrammarStr, cGrammarRoot)
	if grammar == nil {
		return nil, errors.New("failed to initialize grammar")
	}

	cg := &CGrammar{
		grammar: grammar,
		closed:  false,
	}

	// Set up finalizer to free resources when the object is garbage collected;
	// Close is idempotent, so an explicit Close followed by the finalizer is safe.
	runtime.SetFinalizer(cg, func(g *CGrammar) {
		g.Close()
	})

	return cg, nil
}
|
||||
|
||||
// Apply applies grammar constraints to logits
|
||||
func (g *CGrammar) Apply(logits []float32) ([]float32, error) {
|
||||
if g.closed || g.grammar == nil {
|
||||
return nil, errors.New("grammar not initialized or already closed")
|
||||
}
|
||||
|
||||
// Create a copy of logits to modify
|
||||
result := make([]float32, len(logits))
|
||||
copy(result, logits)
|
||||
|
||||
// Apply grammar constraints using our C wrapper function
|
||||
C.grammar_apply_to_logits(g.grammar, (*C.float)(&result[0]), C.int(len(result)))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Close releases resources associated with the grammar
|
||||
func (g *CGrammar) Close() error {
|
||||
if !g.closed && g.grammar != nil {
|
||||
C.grammar_free(g.grammar)
|
||||
g.grammar = nil
|
||||
g.closed = true
|
||||
}
|
||||
return nil
|
||||
}
|
83
llama/grammar_ext.cpp
vendored
Normal file
83
llama/grammar_ext.cpp
vendored
Normal file
@ -0,0 +1,83 @@
|
||||
#include <stdlib.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cstdint>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "llama-sampling.h"
|
||||
#include "llama-grammar.h"
|
||||
#include "llama-vocab.h"
|
||||
#include "grammar_ext.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Create a grammar-constrained sampler from a GBNF grammar string and a
// vocabulary. Returns nullptr if the sampler cannot be created or an
// exception is thrown. The result must be released with grammar_free.
struct llama_grammar* grammar_create_from_string(const struct llama_vocab* vocab, const char* grammar_str, const char* grammar_root) {
    try {
        // Initialize a grammar sampler directly from the vocab
        struct llama_sampler* sampler = llama_sampler_init_grammar(vocab, grammar_str, grammar_root);
        if (!sampler) {
            return nullptr;
        }

        // NOTE(review): this reinterprets a llama_sampler* as a llama_grammar*.
        // grammar_free casts it back before freeing, which is consistent, but
        // grammar_apply_to_logits dereferences it as a llama_grammar — confirm
        // the sampler object actually aliases a grammar at this address.
        return (struct llama_grammar*)sampler;
    } catch (const std::exception &err) {
        return nullptr;
    }
}
|
||||
|
||||
void grammar_apply_to_logits(struct llama_grammar* grammar, float* logits, int n_logits) {
|
||||
if (!grammar || !logits || n_logits <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Create token data array for the grammar application
|
||||
llama_token_data* token_data = (llama_token_data*)malloc(n_logits * sizeof(llama_token_data));
|
||||
if (!token_data) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Initialize token data from logits
|
||||
for (int i = 0; i < n_logits; i++) {
|
||||
token_data[i].id = i;
|
||||
token_data[i].logit = logits[i];
|
||||
token_data[i].p = 0.0f;
|
||||
}
|
||||
|
||||
// Create token data array structure
|
||||
llama_token_data_array arr = {
|
||||
.data = token_data,
|
||||
.size = (size_t)n_logits,
|
||||
.sorted = false,
|
||||
.selected = -1
|
||||
};
|
||||
|
||||
// Apply grammar constraints to the token data array
|
||||
llama_grammar_apply_impl(*grammar, &arr);
|
||||
|
||||
// Copy back the modified logits
|
||||
for (int i = 0; i < n_logits; i++) {
|
||||
logits[i] = token_data[i].logit;
|
||||
}
|
||||
|
||||
free(token_data);
|
||||
}
|
||||
|
||||
void grammar_free(struct llama_grammar* grammar) {
|
||||
if (grammar) {
|
||||
// Free the grammar as a sampler
|
||||
llama_sampler_free((struct llama_sampler*)grammar);
|
||||
}
|
||||
}
|
||||
|
||||
// C-linkage bridge over llama_vocab_from_tokens (added to llama-vocab.cpp
// by patch 0019) so cgo can build a vocabulary from an array of token
// strings. Token id i corresponds to tokens[i]. Returns nullptr on failure;
// release the result with vocab_bridge_free.
struct llama_vocab* vocab_bridge_from_tokens(const char** tokens, int n_tokens) {
    // Call the C++ function from llama-vocab.cpp
    return llama_vocab_from_tokens(tokens, n_tokens);
}
|
||||
|
||||
// C-linkage bridge over llama_vocab_free (added to llama-vocab.cpp by
// patch 0019). Frees a vocabulary created by vocab_bridge_from_tokens;
// a null pointer is handled by the callee.
void vocab_bridge_free(struct llama_vocab* vocab) {
    // Call the C++ function from llama-vocab.cpp
    llama_vocab_free(vocab);
}
|
||||
|
||||
} // extern "C"
|
33
llama/grammar_ext.h
vendored
Normal file
33
llama/grammar_ext.h
vendored
Normal file
@ -0,0 +1,33 @@
|
||||
#ifndef GRAMMAR_EXT_H
#define GRAMMAR_EXT_H

#include "llama.h"

#ifdef __cplusplus
extern "C" {
#endif

// Forward declarations
struct llama_grammar;
struct llama_vocab;

// Create a new grammar from a GBNF string (the grammar is implemented as a
// sampler under the hood). grammar_root names the grammar's root rule.
// Returns NULL on failure; free the result with grammar_free.
struct llama_grammar* grammar_create_from_string(const struct llama_vocab* vocab, const char* grammar_str, const char* grammar_root);

// Apply grammar constraints to logits in place. n_logits is the number of
// entries in logits (one per vocab token id). No-op on NULL/non-positive input.
void grammar_apply_to_logits(struct llama_grammar* grammar, float* logits, int n_logits);

// Free grammar resources (frees the underlying sampler). NULL is a no-op.
void grammar_free(struct llama_grammar* grammar);

// C wrapper for llama_vocab_from_tokens: build a vocabulary where token id i
// is tokens[i]. Returns NULL on failure; free with vocab_bridge_free.
struct llama_vocab* vocab_bridge_from_tokens(const char** tokens, int n_tokens);

// C wrapper for llama_vocab_free.
void vocab_bridge_free(struct llama_vocab* vocab);

#ifdef __cplusplus
}
#endif

#endif // GRAMMAR_EXT_H
|
@ -18,6 +18,7 @@ package llama
|
||||
|
||||
#include "mllama.h"
|
||||
#include "sampling_ext.h"
|
||||
#include "grammar_ext.h"
|
||||
|
||||
extern bool llamaProgressCallback(float progress, void *user_data);
|
||||
extern void llamaLog(int level, char* text, void* user_data);
|
||||
|
117
llama/patches/0019-expose-llama_vocab-from-tokens.patch
Normal file
117
llama/patches/0019-expose-llama_vocab-from-tokens.patch
Normal file
@ -0,0 +1,117 @@
|
||||
From 668a974433edccf2c5fcc2192c39aed601e575f2 Mon Sep 17 00:00:00 2001
|
||||
From: Bruce MacDonald <brucewmacdonald@gmail.com>
|
||||
Date: Thu, 6 Mar 2025 21:07:06 -0800
|
||||
Subject: [PATCH] expose llama_vocab from tokens
|
||||
|
||||
---
|
||||
llama/llama.cpp/src/llama-vocab.cpp | 73 +++++++++++++++++++++++++++++
|
||||
llama/llama.cpp/src/llama-vocab.h | 11 ++++-
|
||||
2 files changed, 83 insertions(+), 1 deletion(-)
|
||||
|
||||
diff --git a/llama/llama.cpp/src/llama-vocab.cpp b/llama/llama.cpp/src/llama-vocab.cpp
|
||||
index c7ff28be..ad6e7ad8 100644
|
||||
--- a/llama/llama.cpp/src/llama-vocab.cpp
|
||||
+++ b/llama/llama.cpp/src/llama-vocab.cpp
|
||||
@@ -3253,3 +3253,76 @@ int32_t llama_detokenize(
|
||||
return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
|
||||
}
|
||||
|
||||
+struct llama_vocab *llama_vocab_from_tokens(const char **tokens, int n_tokens)
|
||||
+{
|
||||
+ if (!tokens || n_tokens <= 0)
|
||||
+ {
|
||||
+ return nullptr;
|
||||
+ }
|
||||
+
|
||||
+ try
|
||||
+ {
|
||||
+ // Create a new vocabulary instance
|
||||
+ llama_vocab *vocab = new llama_vocab();
|
||||
+ vocab->pimpl = std::make_unique<llama_vocab::impl>(*vocab);
|
||||
+
|
||||
+ // Resize the token data vectors
|
||||
+ vocab->pimpl->id_to_token.resize(n_tokens);
|
||||
+
|
||||
+ // Create mappings for all tokens
|
||||
+ for (int i = 0; i < n_tokens; i++)
|
||||
+ {
|
||||
+ std::string word = tokens[i];
|
||||
+ if (word.empty())
|
||||
+ {
|
||||
+ word = "[EMPTY_" + std::to_string(i) + "]";
|
||||
+ }
|
||||
+
|
||||
+ // Add to token mappings
|
||||
+ vocab->pimpl->token_to_id[word] = i;
|
||||
+
|
||||
+ // Set up token data
|
||||
+ auto &token_data = vocab->pimpl->id_to_token[i];
|
||||
+ token_data.text = std::move(word);
|
||||
+ token_data.score = 0.0f; // Default score
|
||||
+ token_data.attr = LLAMA_TOKEN_ATTR_NORMAL;
|
||||
+
|
||||
+ // Detect special tokens
|
||||
+ if (word == "<s>" || word == "<bos>")
|
||||
+ {
|
||||
+ vocab->pimpl->special_bos_id = i;
|
||||
+ }
|
||||
+ else if (word == "</s>" || word == "<eos>" || word == "<|endoftext|>")
|
||||
+ {
|
||||
+ vocab->pimpl->special_eos_id = i;
|
||||
+ vocab->pimpl->special_eog_ids.insert(i);
|
||||
+ }
|
||||
+ else if (word == "<unk>")
|
||||
+ {
|
||||
+ vocab->pimpl->special_unk_id = i;
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ // Initialize the token-to-piece cache
|
||||
+ vocab->pimpl->cache_token_to_piece.resize(n_tokens);
|
||||
+ for (int i = 0; i < n_tokens; i++)
|
||||
+ {
|
||||
+ vocab->pimpl->cache_token_to_piece[i] = vocab->pimpl->id_to_token[i].text;
|
||||
+ }
|
||||
+
|
||||
+ return vocab;
|
||||
+ }
|
||||
+ catch (const std::exception &err)
|
||||
+ {
|
||||
+ return nullptr;
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+// Helper function to free the vocab
|
||||
+void llama_vocab_free(struct llama_vocab *vocab)
|
||||
+{
|
||||
+ if (vocab)
|
||||
+ {
|
||||
+ delete vocab;
|
||||
+ }
|
||||
+}
|
||||
\ No newline at end of file
|
||||
diff --git a/llama/llama.cpp/src/llama-vocab.h b/llama/llama.cpp/src/llama-vocab.h
|
||||
index 5ce35521..eceb28f3 100644
|
||||
--- a/llama/llama.cpp/src/llama-vocab.h
|
||||
+++ b/llama/llama.cpp/src/llama-vocab.h
|
||||
@@ -119,7 +119,16 @@ struct llama_vocab {
|
||||
|
||||
void print_info() const;
|
||||
|
||||
-private:
|
||||
struct impl;
|
||||
std::unique_ptr<impl> pimpl;
|
||||
};
|
||||
+
|
||||
+// Create a vocabulary from an array of token strings
|
||||
+// tokens: Array of token strings
|
||||
+// n_tokens: Number of tokens in the array
|
||||
+// Returns: A new llama_vocab instance, or nullptr on failure
|
||||
+// The caller is responsible for freeing the vocabulary using llama_vocab_free
|
||||
+LLAMA_API struct llama_vocab * llama_vocab_from_tokens(const char ** tokens, int n_tokens);
|
||||
+
|
||||
+// Free a vocabulary created with llama_vocab_from_tokens
|
||||
+LLAMA_API void llama_vocab_free(struct llama_vocab * vocab);
|
||||
--
|
||||
2.39.3 (Apple Git-145)
|
||||
|
@ -428,7 +428,8 @@ func (s *Server) processBatch() error {
|
||||
|
||||
// sample a token
|
||||
vocabSize := len(logits) / len(options.Outputs)
|
||||
|
||||
// TODO: need access to vocab to apply grammar
|
||||
// token = sampler.Grammar.Apply(logits)
|
||||
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sample token: %w", err)
|
||||
@ -575,6 +576,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: if grammar is provided, load it
|
||||
// if req.Grammar != "" {
|
||||
// grammar := llama.NewGrammarWithTokens(req.Grammar, "root", s.model.Vocabulary)
|
||||
// }
|
||||
// defer grammar.Close()
|
||||
// sampler := sample.WithGrammar(sample.Greedy(), grammar)
|
||||
|
||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||
numPredict: req.NumPredict,
|
||||
stop: req.Stop,
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/llama"
|
||||
"golang.org/x/exp/rand"
|
||||
"gonum.org/v1/gonum/stat/sampleuv"
|
||||
)
|
||||
@ -57,12 +58,24 @@ func (s weighted) Sample(logits []float32) (int32, error) {
|
||||
return -1, errors.New("weighted sampler failed, no valid token found")
|
||||
}
|
||||
|
||||
type greedy struct{}
|
||||
// greedy selects the highest-logit token.
// NOTE(review): grammar is set by WithGrammar, but its use inside Sample is
// not visible in this diff — confirm Sample actually applies it.
type greedy struct {
	grammar llama.Grammar
}
|
||||
|
||||
// Greedy returns a Sampler that picks the token with the maximum logit.
// No grammar is attached; use WithGrammar to add one.
func Greedy() Sampler {
	return greedy{}
}
|
||||
|
||||
func WithGrammar(s Sampler, grammar llama.Grammar) Sampler {
|
||||
switch t := s.(type) {
|
||||
case greedy:
|
||||
t.grammar = grammar
|
||||
return t
|
||||
default:
|
||||
return s
|
||||
}
|
||||
}
|
||||
|
||||
// Sample returns the index of the maximum value in logits.
|
||||
func (s greedy) Sample(logits []float32) (int32, error) {
|
||||
if len(logits) == 0 {
|
||||
|
Loading…
x
Reference in New Issue
Block a user