llama: shim to use ollama_vocab instead of copying

This commit is contained in:
ParthSareen 2025-04-03 14:58:49 -07:00
parent 377fa9d2ba
commit b38db7166d
9 changed files with 102 additions and 1480 deletions

1261
llama/grammar.cpp vendored

File diff suppressed because it is too large Load Diff

171
llama/grammar.h vendored
View File

@ -1,171 +0,0 @@
#pragma once
#include "llama.h"
#include <map>
#include <string>
#include <vector>
// Minimal vocabulary shim used by the grammar engine instead of the full
// llama vocab: maps grammar symbol names to ids, tokens to their UTF-8 text
// pieces, and records the end-of-generation (EOG) token.
struct ollama_vocab {
    std::map<std::string, uint32_t> symbol_ids;     // grammar symbol name -> symbol id
    std::map<uint32_t, std::string> token_to_piece; // token id -> UTF-8 text piece
    uint32_t eog_token; // NOTE(review): uninitialized until set_eog_token() runs — confirm callers always set it
    void add_symbol_id(const std::string & symbol, uint32_t id);     // register a symbol name under the given id
    void add_token_piece(uint32_t token, const std::string & piece); // register the text piece for a token id
    void set_eog_token(uint32_t token);                              // record which token ends generation
};
// grammar element type
// grammar element type: the role of a single element inside a compiled
// grammar rule (rules are flat arrays of grammar_element, see below)
enum gretype {
    // end of rule definition
    GRETYPE_END            = 0,
    // start of alternate definition for rule
    GRETYPE_ALT            = 1,
    // non-terminal element: reference to rule
    GRETYPE_RULE_REF       = 2,
    // terminal element: character (code point)
    GRETYPE_CHAR           = 3,
    // inverse char(s) ([^a], [^a-b] [^abc])
    GRETYPE_CHAR_NOT       = 4,
    // modifies a preceding GRETYPE_CHAR or GRETYPE_CHAR_ALT to
    // be an inclusive range ([a-z])
    GRETYPE_CHAR_RNG_UPPER = 5,
    // modifies a preceding GRETYPE_CHAR or
    // GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
    GRETYPE_CHAR_ALT       = 6,
    // any character (.)
    GRETYPE_CHAR_ANY       = 7,
};
// One element of a compiled grammar rule: a type tag plus its payload.
// The meaning of `value` depends on `type` (see gretype above).
typedef struct grammar_element {
    enum gretype type;
    uint32_t     value; // Unicode code point or rule ID
} grammar_element;
// State of an in-progress UTF-8 decode across token boundaries: a token's
// piece may end mid-sequence, so the partial code point is carried over.
struct partial_utf8 {
    uint32_t value;    // bit value so far (unshifted)
    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
};
// A candidate token under consideration for grammar-constrained sampling.
struct grammar_candidate {
    size_t           index;       // index of the candidate in the token-data array
    const uint32_t * code_points; // decoded code points of the candidate's piece
                                  // NOTE(review): presumably 0-terminated — confirm against producer
    partial_utf8     utf8_state;  // trailing partial UTF-8 sequence, if any
};
using grammar_rule = std::vector< grammar_element>;
using grammar_stack = std::vector<const grammar_element *>;
using grammar_rules = std::vector<grammar_rule>;
using grammar_stacks = std::vector<grammar_stack>;
using grammar_candidates = std::vector<grammar_candidate>;
// TODO: remove, needed for tests atm
const grammar_rules & grammar_get_rules (const struct grammar * grammar);
grammar_stacks & grammar_get_stacks( struct grammar * grammar);
// takes a set of possible pushdown stacks on a grammar, which are required to
// be positioned at a character range (see `grammar_advance_stack`), and
// produces the N possible stacks if the given char is accepted at those
// positions
void grammar_accept(struct grammar * grammar, uint32_t chr);
std::vector<grammar_candidate> grammar_reject_candidates_for_stack(
const grammar_rules & rules,
const grammar_stack & stack,
const grammar_candidates & candidates);
// Parser for GBNF grammar text: builds symbol ids and compiled rules from a
// grammar source string. Method bodies live in the implementation file.
struct grammar_parser {
    std::map<std::string, uint32_t> symbol_ids; // rule name -> rule id
    grammar_rules rules;                        // compiled rules, indexed by rule id

    // pointers into `rules` (one per rule); valid only while `rules` is unmodified
    grammar_stack c_rules() const;

    // id of an existing symbol named src[0..len), presumably creating it if absent — confirm in impl
    uint32_t get_symbol_id(const char * src, size_t len);
    // fresh id for an anonymous rule derived from base_name (used for nested constructs)
    uint32_t generate_symbol_id(const std::string & base_name);

    void add_rule(uint32_t rule_id, const grammar_rule & rule);

    // recursive-descent helpers; each returns the position after the parsed span
    const char * parse_alternates(
            const char * src,
            const std::string & rule_name,
            uint32_t rule_id,
            bool is_nested);
    const char * parse_sequence(
            const char * src,
            const std::string & rule_name,
            grammar_rule & rule,
            bool is_nested);
    const char * parse_rule(const char * src);

    // parse a full grammar; returns whether parsing succeeded
    bool parse(const char * src);
    // dump the parsed grammar (for debugging)
    void print(FILE * file);
};
// A compiled grammar plus its runtime matching state (pushdown stacks).
struct grammar {
    // note: allow null vocab for testing (not great)
    ollama_vocab * vocab;

    const grammar_rules  rules;  // TODO: shared ptr
    grammar_stacks       stacks; // set of possible parse positions (pushdown stacks)

    // buffer for partially generated UTF-8 sequence from accepted tokens
    partial_utf8 utf8_state;

    // lazy grammars wait for trigger words or tokens before constraining the sampling.
    // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
    // (useful e.g. for tool_choice=required)
    bool                     lazy             = false;
    bool                     awaiting_trigger = false; // Initialized to true for lazy grammars only
    std::string              trigger_buffer;           // Output buffered by lazy grammar. Will be cleared once trigger is found.
    std::vector<llama_token> trigger_tokens;           // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
    std::vector<std::string> trigger_words;
};
//
// internal API
//
// note: needed for tests (not great)
struct grammar * grammar_init_impl(
struct ollama_vocab * ollama_vocab,
const grammar_element ** rules,
size_t n_rules,
size_t start_rule_index);
struct grammar * grammar_init_impl(
struct ollama_vocab * ollama_vocab,
const char * grammar_str,
const char * grammar_root,
bool lazy,
const char ** trigger_words,
size_t num_trigger_words,
const llama_token * trigger_tokens,
size_t num_trigger_tokens);
void grammar_free_impl(struct grammar * grammar);
struct grammar * grammar_clone_impl(const struct grammar & grammar);
// TODO(parthsareen): move the API below as member functions of grammar
void grammar_apply_impl(
const struct grammar & grammar,
llama_token_data_array * cur_p);
void grammar_accept_impl(
struct grammar & grammar,
llama_token token);
void grammar_accept_str(
struct grammar & grammar,
const std::string & piece);

View File

@ -907,6 +907,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
struct ollama_vocab * ollama_vocab,
const llama_grammar_element ** rules,
size_t n_rules,
size_t start_rule_index) {
@ -962,6 +963,7 @@ struct llama_grammar * llama_grammar_init_impl(
// then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar {
vocab,
ollama_vocab,
std::move(vec_rules),
std::move(stacks),
/* .partial_utf8 = */ {},
@ -975,6 +977,7 @@ struct llama_grammar * llama_grammar_init_impl(
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
struct ollama_vocab * ollama_vocab,
const char * grammar_str,
const char * grammar_root,
bool lazy,
@ -1069,6 +1072,7 @@ struct llama_grammar * llama_grammar_init_impl(
// then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar {
vocab,
ollama_vocab,
std::move(vec_rules),
std::move(stacks),
/* .partial_utf8 = */ {},
@ -1091,6 +1095,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
llama_grammar * result = new llama_grammar {
grammar.vocab,
grammar.ollama_vocab,
grammar.rules,
grammar.stacks,
grammar.partial_utf8,
@ -1118,7 +1123,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
}
void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
GGML_ASSERT(grammar.vocab != nullptr);
GGML_ASSERT(!(grammar.vocab == nullptr && grammar.ollama_vocab == nullptr));
if (grammar.awaiting_trigger) {
return;
@ -1140,9 +1145,21 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
for (size_t i = 0; i < cur_p->size; ++i) {
const llama_token id = cur_p->data[i].id;
const std::string & piece = grammar.vocab->token_to_piece(id);
std::string piece;
if (grammar.ollama_vocab) {
piece = grammar.ollama_vocab->token_to_piece(id);
} else {
piece = grammar.vocab->token_to_piece(id);
}
if (grammar.vocab->is_eog(id)) {
bool is_eog = false;
if (grammar.ollama_vocab) {
is_eog = grammar.ollama_vocab->is_eog(id);
} else {
is_eog = grammar.vocab->is_eog(id);
}
if (is_eog) {
if (!allow_eog) {
cur_p->data[i].logit = -INFINITY;
}
@ -1161,9 +1178,14 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
}
void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
GGML_ASSERT(grammar.vocab != nullptr);
GGML_ASSERT(!(grammar.vocab == nullptr && grammar.ollama_vocab == nullptr));
const auto & piece = grammar.vocab->token_to_piece(token);
std::string piece;
if (grammar.ollama_vocab) {
piece = grammar.ollama_vocab->token_to_piece(token);
} else {
piece = grammar.vocab->token_to_piece(token);
}
if (grammar.awaiting_trigger) {
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
@ -1191,13 +1213,24 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
}
}
if (grammar.vocab->is_eog(token)) {
for (const auto & stack : grammar.stacks) {
if (stack.empty()) {
return;
if (grammar.ollama_vocab) {
if (grammar.ollama_vocab->is_eog(token)) {
for (const auto & stack : grammar.stacks) {
if (stack.empty()) {
return;
}
}
GGML_ABORT("grammar error: end of grammar token received but grammar stack is not empty");
}
} else {
if (grammar.vocab->is_eog(token)) {
for (const auto & stack : grammar.stacks) {
if (stack.empty()) {
return;
}
}
GGML_ABORT("grammar error: end of grammar token received but grammar stack is not empty");
}
GGML_ABORT("fatal error");
}
llama_grammar_accept_str(grammar, piece);
@ -1217,3 +1250,20 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
}
}
// Look up the UTF-8 text piece registered for `token`.
// Returns a reference to a shared empty string for unknown tokens. This
// deliberately avoids std::map::operator[], which would silently insert an
// empty entry (mutating and growing token_to_piece_map) on every miss while
// still handing the caller an empty string.
const std::string & ollama_vocab::token_to_piece(uint32_t token) {
    static const std::string empty;
    const auto it = token_to_piece_map.find(token);
    return it != token_to_piece_map.end() ? it->second : empty;
}
// A token is end-of-generation iff it equals the single registered EOG token
// (set via set_eog_token).
bool ollama_vocab::is_eog(uint32_t token) {
    return token == eog_token;
}
// Register (or overwrite) the text piece for a token id.
void ollama_vocab::add_token_piece(uint32_t token, const std::string & piece) {
    token_to_piece_map[token] = piece;
}
// Record which token marks end-of-generation; consulted by is_eog().
void ollama_vocab::set_eog_token(uint32_t token) {
    eog_token = token;
}

View File

@ -7,6 +7,17 @@
#include <vector>
struct llama_vocab;
// Minimal vocabulary shim supplied by the Go side: token-id -> text-piece
// mapping plus the end-of-generation token. Used by llama_grammar when the
// full llama_vocab is not available (grammar.vocab may be null then).
struct ollama_vocab {
    std::map<uint32_t, std::string> token_to_piece_map; // token id -> UTF-8 text piece
    uint32_t eog_token; // NOTE(review): uninitialized until set_eog_token() runs — confirm callers always set it before is_eog()

    void add_token_piece(uint32_t token, const std::string & piece); // register the piece for a token id
    void set_eog_token(uint32_t token);                              // record the end-of-generation token
    const std::string & token_to_piece(uint32_t token);              // piece for a token id
    bool is_eog(uint32_t token);                                     // true iff token is the registered EOG token
};
// grammar element type
enum llama_gretype {
@ -108,6 +119,7 @@ struct llama_grammar_parser {
struct llama_grammar {
// note: allow null vocab for testing (not great)
const llama_vocab * vocab;
ollama_vocab * ollama_vocab;
const llama_grammar_rules rules; // TODO: shared ptr
llama_grammar_stacks stacks;
@ -132,12 +144,14 @@ struct llama_grammar {
// note: needed for tests (not great)
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
struct ollama_vocab * ollama_vocab,
const llama_grammar_element ** rules,
size_t n_rules,
size_t start_rule_index);
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab * vocab,
struct ollama_vocab * ollama_vocab,
const char * grammar_str,
const char * grammar_root,
bool lazy,

View File

@ -1461,7 +1461,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
for (auto & word : ctx->grammar->trigger_words) {
trigger_words.push_back(word.c_str());
}
auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, nullptr, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
ctx->grammar->lazy, trigger_words.data(), trigger_words.size(),
ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
@ -1524,7 +1524,7 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
/* .vocab = */ vocab,
/* .grammar_str = */ grammar_str,
/* .grammar_root = */ grammar_root,
/* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens),
/* .grammar = */ llama_grammar_init_impl(vocab, nullptr, grammar_str, grammar_root, lazy, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens),
};
} else {
*ctx = {

View File

@ -680,26 +680,20 @@ type TokenData struct {
}
type Grammar struct {
c *C.struct_grammar
}
func (g *Grammar) AddSymbol(symbol string, id uint32) {
cSymbol := C.CString(symbol)
defer C.free(unsafe.Pointer(cSymbol))
C.grammar_add_symbol_id(g.c, cSymbol, C.uint32_t(id))
c *C.struct_llama_grammar
}
func (g *Grammar) AddTokenPiece(token uint32, piece string) {
cPiece := C.CString(piece)
defer C.free(unsafe.Pointer(cPiece))
C.grammar_add_token_piece(g.c, C.uint32_t(token), cPiece)
C.ollama_vocab_add_token_piece(g.c, C.uint32_t(token), cPiece)
}
func (g *Grammar) SetEOGToken(token uint32) {
C.grammar_set_eog_token(g.c, C.uint32_t(token))
C.ollama_vocab_set_eog_token(g.c, C.uint32_t(token))
}
func InitGrammarChain(grammar string) *Grammar {
func LoadGrammar(grammar string) *Grammar {
cGrammar := C.CString(grammar)
defer C.free(unsafe.Pointer(cGrammar))

View File

@ -2,10 +2,10 @@
#include "sampling.h"
#include "sampling_ext.h"
#include "json-schema-to-grammar.h"
#include "grammar.h"
#include "llama.h"
#include "llama-model.h"
#include "llama-model-loader.h"
#include "llama-grammar.h"
struct common_sampler *common_sampler_cinit(const struct llama_model *model, struct common_sampler_cparams *params) {
try {
@ -70,7 +70,7 @@ int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len)
}
struct grammar *grammar_init(char* grammar) {
struct llama_grammar *grammar_init(char* grammar) {
if (grammar == nullptr) {
LLAMA_LOG_ERROR("%s: null grammar input\n", __func__);
return nullptr;
@ -88,7 +88,7 @@ struct grammar *grammar_init(char* grammar) {
// Initialize grammar with the vocab
struct grammar *g = grammar_init_impl(vocab, grammar, "root", false, nullptr, 0, nullptr, 0);
struct llama_grammar *g = llama_grammar_init_impl(nullptr, vocab, grammar, "root", false, nullptr, 0, nullptr, 0);
if (g == nullptr) {
LLAMA_LOG_ERROR("%s: failed to initialize grammar\n", __func__);
delete vocab;
@ -103,31 +103,28 @@ struct grammar *grammar_init(char* grammar) {
}
}
void grammar_free(struct grammar *g) {
void grammar_free(struct llama_grammar *g) {
if (g != nullptr) {
if (g->vocab != nullptr) {
delete g->vocab;
}
grammar_free_impl(g);
llama_grammar_free_impl(g);
}
}
void grammar_apply(struct grammar *g, struct llama_token_data_array *tokens) {
grammar_apply_impl(*g, tokens);
void grammar_apply(struct llama_grammar *g, struct llama_token_data_array *tokens) {
llama_grammar_apply_impl(*g, tokens);
}
void grammar_accept(struct grammar *g, llama_token id) {
grammar_accept_impl(*g, id);
void grammar_accept(struct llama_grammar *g, llama_token id) {
llama_grammar_accept_impl(*g, id);
}
void grammar_add_symbol_id(struct grammar *g, const char *symbol, uint32_t id) {
g->vocab->add_symbol_id(symbol, id);
void ollama_vocab_add_token_piece(struct llama_grammar *g, uint32_t token, const char *piece) {
g->ollama_vocab->add_token_piece(token, piece);
}
void grammar_add_token_piece(struct grammar *g, uint32_t token, const char *piece) {
g->vocab->add_token_piece(token, piece);
}
void grammar_set_eog_token(struct grammar *g, uint32_t token) {
g->vocab->set_eog_token(token);
void ollama_vocab_set_eog_token(struct llama_grammar *g, uint32_t token) {
g->ollama_vocab->set_eog_token(token);
}

14
llama/sampling_ext.h vendored
View File

@ -37,13 +37,13 @@ extern "C"
struct ollama_vocab;
struct grammar *grammar_init(char* grammar);
void grammar_free(struct grammar *g);
void grammar_apply(struct grammar *g, struct llama_token_data_array *tokens);
void grammar_accept(struct grammar *g, llama_token id);
void grammar_add_symbol_id(struct grammar *g, const char *symbol, uint32_t id);
void grammar_add_token_piece(struct grammar *g, uint32_t token, const char *piece);
void grammar_set_eog_token(struct grammar *g, uint32_t token);
struct llama_grammar *grammar_init(char* grammar);
void grammar_free(struct llama_grammar *g);
void grammar_apply(struct llama_grammar *g, struct llama_token_data_array *tokens);
void grammar_accept(struct llama_grammar *g, llama_token id);
void ollama_vocab_add_token_piece(struct llama_grammar *g, uint32_t token, const char *piece);
void ollama_vocab_set_eog_token(struct llama_grammar *g, uint32_t token);
#ifdef __cplusplus
}

View File

@ -170,13 +170,12 @@ type Grammar struct {
}
func NewGrammar(vocab *model.Vocabulary, grammarStr string) (*Grammar, error) {
grammar := llama.InitGrammarChain(grammarStr)
grammar := llama.LoadGrammar(grammarStr)
if grammar == nil {
return nil, errors.New("sample: failed to initialize grammar")
}
for _, s := range vocab.Values {
id := vocab.Encode(s)
grammar.AddSymbol(s, uint32(id))
grammar.AddTokenPiece(uint32(id), s)
}
grammar.SetEOGToken(uint32(vocab.EOS))