/** * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file * * MIT License * * Copyright (c) 2023-2024 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ #pragma once #include "llama.h" #include #include // very similar to llama_batch, // but has more metadata about sequences struct llama_ubatch { bool equal_seqs; // TODO: whole_seqs for embeddings? uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) uint32_t n_seq_tokens; // tokens per sequence uint32_t n_seqs; llama_token * token; // [n_tokens] float * embd; // [n_embd, n_tokens] llama_pos * pos; // [n_tokens] int32_t * n_seq_id; // [n_seqs] llama_seq_id ** seq_id; // [n_seqs] int8_t * output; // [n_tokens] }; struct llama_sbatch_seq { int32_t n_seq_id; llama_seq_id * seq_id; size_t offset; size_t length; }; // sequence-length-aware batch splitting struct llama_sbatch { // tokens left in this batch size_t n_tokens; size_t n_embd; bool logits_all; // TODO: remove once lctx.logits_all is removed too // sorted indices into the batch std::vector ids; // batch indices of the output std::vector out_ids; std::vector seq; const llama_batch * batch = nullptr; // buffers for the ubatch std::vector ubatch_token; std::vector ubatch_embd; std::vector ubatch_pos; std::vector ubatch_n_seq_id; std::vector ubatch_seq_id; std::vector ubatch_output; llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false); void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length); // simple split, unknown number of sequences of unequal lengths llama_ubatch split_simple(size_t n_ubatch); // make batches of equal-length sequences llama_ubatch split_equal(size_t n_ubatch); // sequence-wise split llama_ubatch split_seq(size_t n_ubatch); void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); }; // temporary allocate memory for the input batch if needed struct llama_batch_allocr { struct llama_batch batch; std::array seq_id_0 = { 0 }; // default sequence id std::vector pos; std::vector n_seq_id; std::vector seq_id; std::vector logits; // optionally fulfill the batch returned by llama_batch_get_one llama_batch_allocr(struct llama_batch in_batch, llama_pos p0); };