diff --git a/llama/llama.go b/llama/llama.go index 1d4513e31..a20f23578 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -199,21 +199,25 @@ func (c *Context) KvCacheDefrag() { // Get the embeddings for a sequence id func (c *Context) GetEmbeddingsSeq(seqId int) []float32 { - embeddings := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId))) - if embeddings == nil { + e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId))) + if e == nil { return nil } - return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd()) + embeddings := make([]float32, c.Model().NEmbd()) + _ = copy(embeddings, unsafe.Slice((*float32)(e), c.Model().NEmbd())) + return embeddings } func (c *Context) GetEmbeddingsIth(i int) []float32 { - embeddings := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i))) - if embeddings == nil { + e := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i))) + if e == nil { return nil } - return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd()) + embeddings := make([]float32, c.Model().NEmbd()) + _ = copy(embeddings, unsafe.Slice((*float32)(e), c.Model().NEmbd())) + return embeddings } type ModelParams struct {