diff --git a/llama/llama.go b/llama/llama.go index ac2c112c29..88672a0330 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -515,33 +515,34 @@ func (c *MtmdContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32, } nChunks := C.mtmd_input_chunks_size(ic) numEmbed := llamaContext.Model().NEmbd() - lastChunkSize := 0 + embed := make([][]float32, 0) for i := range int(nChunks) { chunk := C.mtmd_input_chunks_get(ic, C.size_t(i)) numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk)) - lastChunkSize = numTokens + slog.Debug("chunk tokens", "index", i, "numTokens", numTokens) // Encode the chunk if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) { return nil, errors.New("unable to encode mtmd image chunk") } - } - // Get the embeddings - embed := make([][]float32, lastChunkSize) - embd := C.mtmd_get_output_embd(c.c) - if nil == embd { - return nil, errors.New("failed to get image embedding") - } + // Get the embeddings for this chunk + chunkEmbed := make([][]float32, numTokens) + chunkEmbd := C.mtmd_get_output_embd(c.c) + if nil == chunkEmbd { + continue + } - // Extend the embedding array for each token - s := unsafe.Slice((*float32)(embd), numEmbed*lastChunkSize) - rows := make([]float32, len(s)) - copy(rows, s) - for i := range lastChunkSize { - embed[i] = rows[i*numEmbed : (i+1)*numEmbed] + // Extend the embedding array for each token + s := unsafe.Slice((*float32)(chunkEmbd), numTokens*numEmbed) + rows := make([]float32, len(s)) + copy(rows, s) + for i := range numTokens { + chunkEmbed[i] = rows[i*numEmbed : (i+1)*numEmbed] + } + embed = append(embed, chunkEmbed...) } - + slog.Debug("image embeddings", "totalEmbeddings", len(embed)) return embed, nil }