llama: fix naming in grammar

This commit is contained in:
ParthSareen 2025-04-03 11:02:13 -07:00
parent a0022981c7
commit 106592820d
3 changed files with 12 additions and 15 deletions

18
llama/grammar.cpp vendored
View File

@ -871,7 +871,7 @@ grammar_candidates grammar_reject_candidates_for_stack(
if (stack.empty()) {
for (const auto & tok : candidates) {
if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
if (*tok.code_points != 0 || tok.utf8_state.n_remain != 0) {
rejects.push_back(tok);
}
}
@ -887,12 +887,12 @@ grammar_candidates grammar_reject_candidates_for_stack(
if (*tok.code_points == 0) {
// reached end of full codepoints in token, reject iff it ended in a partial sequence
// that cannot satisfy this position in grammar
if (tok.partial_utf8.n_remain != 0 &&
!grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
if (tok.utf8_state.n_remain != 0 &&
!grammar_match_partial_char(stack_pos, tok.utf8_state)) {
rejects.push_back(tok);
}
} else if (grammar_match_char(stack_pos, *tok.code_points).first) {
next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
next_candidates.push_back({ tok.index, tok.code_points + 1, tok.utf8_state });
} else {
rejects.push_back(tok);
}
@ -910,7 +910,7 @@ grammar_candidates grammar_reject_candidates_for_stack(
auto next_rejects = grammar_reject_candidates(rules, next_stacks, next_candidates);
for (const auto & tok : next_rejects) {
rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
rejects.push_back({ tok.index, tok.code_points - 1, tok.utf8_state });
}
return rejects;
@ -1123,7 +1123,7 @@ struct grammar * grammar_clone_impl(const struct grammar & g) {
g.vocab,
g.rules,
g.stacks,
g.partial_utf8,
g.utf8_state,
g.lazy,
g.awaiting_trigger,
g.trigger_buffer,
@ -1179,7 +1179,7 @@ void grammar_apply_impl(const struct grammar & grammar, llama_token_data_array *
} else if (piece.empty() || piece[0] == 0) {
cur_p->data[i].logit = -INFINITY;
} else {
candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
candidates_decoded.push_back(decode_utf8(piece, grammar.utf8_state));
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
}
}
@ -1235,14 +1235,14 @@ void grammar_accept_impl(struct grammar & grammar, llama_token token) {
void grammar_accept_str(struct grammar & grammar, const std::string & piece) {
// Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
const auto decoded = decode_utf8(piece, grammar.utf8_state);
const auto & code_points = decoded.first;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
grammar_accept(&grammar, *it);
}
grammar.partial_utf8 = decoded.second;
grammar.utf8_state = decoded.second;
if (grammar.stacks.empty()) {
throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
}

4
llama/grammar.h vendored
View File

@ -58,7 +58,7 @@ struct partial_utf8 {
struct grammar_candidate {
size_t index;
const uint32_t * code_points;
partial_utf8 partial_utf8;
partial_utf8 utf8_state;
};
using grammar_rule = std::vector< grammar_element>;
@ -120,7 +120,7 @@ struct grammar {
grammar_stacks stacks;
// buffer for partially generated UTF-8 sequence from accepted tokens
partial_utf8 partial_utf8;
partial_utf8 utf8_state;
// lazy grammars wait for trigger words or tokens before constraining the sampling.
// we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens.

View File

@ -655,8 +655,7 @@ string ::=
)* "\""
number ::= "-"? ("0" | [1-9] [0-9]*) ("." [0-9]+)? ([eE] [-+]? [0-9]+)?
ws ::= [ \t \n \r]*
s ::= [ \n \t]
t ::= [ \t \r]*`
s ::= [ \n \t]`
const maxBufferSize = 512 * format.KiloByte
@ -694,8 +693,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
break
case `"json"`:
req.Grammar = grammarJSON
slog.Info("using JSON grammar")
slog.Info(req.Grammar)
default:
if req.Format[0] != '{' {
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)