From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Mon, 16 Sep 2024 15:53:14 -0700
Subject: [PATCH] embeddings

---
 src/llama-context.cpp | 2 +-
 src/llama.cpp         | 6 ++++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 38a55fb2..b9c4a5bf 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -475,7 +475,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
     const auto n_embd  = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
-    const bool has_logits = !cparams.embeddings;
+    const bool has_logits =  cparams.causal_attn;
     const bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
 
     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
diff --git a/src/llama.cpp b/src/llama.cpp
index ea78ea48..4eb3f6b9 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -10876,7 +10876,6 @@ static int llama_decode_internal(
         res  = nullptr;
         embd = nullptr;
     } else if (cparams.embeddings) {
-        res  = nullptr; // do not extract logits for embedding case
         embd = nullptr;
         for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
             if (strcmp(ggml_graph_node(gf, i)->name, "result_embd_pooled") == 0) {
@@ -10884,12 +10883,15 @@
                 break;
             }
         }
-        GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
     } else {
         embd = nullptr; // do not extract embeddings when not needed
         GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
     }
 
+    if (!cparams.causal_attn) {
+        res = nullptr; // do not extract logits when not needed
+    }
+
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
     ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
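
Not part of the patch, just an illustration: a minimal sketch of a caller exercising the path this change affects, where logits presence now follows cparams.causal_attn instead of !cparams.embeddings. It assumes the llama.cpp C API of roughly this vintage (llama_new_context_with_model, llama_set_causal_attn, llama_get_embeddings_seq); the model path is a placeholder and tokenization/decoding are elided.

// Illustration only -- not part of the patch. Minimal sketch of a caller
// hitting the code path this change affects. Assumes the llama.cpp C API
// circa September 2024; "model.gguf" is a placeholder path.
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("model.gguf", mparams);
    if (model == NULL) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                    // reserve the pooled-embeddings buffer
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN; // produces the "result_embd_pooled" tensor

    llama_context * ctx = llama_new_context_with_model(model, cparams);

    // Before this patch, embeddings == true unconditionally dropped logits.
    // After it, logit extraction follows causal_attn, so a caller skips
    // logits by disabling causal attention instead:
    llama_set_causal_attn(ctx, false);

    // ... tokenize a prompt and call llama_decode() here (elided) ...

    // After a successful decode, the pooled embedding for sequence 0 is:
    // const float * embd = llama_get_embeddings_seq(ctx, 0); // llama_n_embd(model) floats

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}

One consequence of the new arrangement: with embeddings enabled and causal attention left on, both the logits and the pooled-embeddings buffers are reserved, so a single context can serve generation and embedding requests rather than being locked to one mode.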