From 928911bc683a9234343e2542e1a13564dd0f2684 Mon Sep 17 00:00:00 2001 From: Diego Pereira <309799+dpereira@users.noreply.github.com> Date: Wed, 5 Feb 2025 21:53:33 -0300 Subject: [PATCH] runner: avoid buffer overwrite when generating multiple embeddings (#8714) Shield the code processing the embedding result from subsequent calls that may overwrite the same buffer to process a second input when retrieving model embeddings. --- llama/llama.go | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/llama/llama.go b/llama/llama.go index 1d4513e31..a20f23578 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -199,21 +199,25 @@ func (c *Context) KvCacheDefrag() { // Get the embeddings for a sequence id func (c *Context) GetEmbeddingsSeq(seqId int) []float32 { - embeddings := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId))) - if embeddings == nil { + e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId))) + if e == nil { return nil } - return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd()) + embeddings := make([]float32, c.Model().NEmbd()) + _ = copy(embeddings, unsafe.Slice((*float32)(e), c.Model().NEmbd())) + return embeddings } func (c *Context) GetEmbeddingsIth(i int) []float32 { - embeddings := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i))) - if embeddings == nil { + e := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i))) + if e == nil { return nil } - return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd()) + embeddings := make([]float32, c.Model().NEmbd()) + _ = copy(embeddings, unsafe.Slice((*float32)(e), c.Model().NEmbd())) + return embeddings } type ModelParams struct {