fix: Update llama.go to use mtmd instead of clip/llava

It's _very_ possible that this is broken!

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
Gabe Goodhart
2025-06-24 17:48:31 -06:00
parent fa54a3cf3a
commit 3d70237fd1
2 changed files with 56 additions and 42 deletions

View File

@@ -13,8 +13,7 @@ package llama
#include <stdlib.h> #include <stdlib.h>
#include "ggml.h" #include "ggml.h"
#include "llama.h" #include "llama.h"
#include "clip.h" #include "mtmd.h"
#include "llava.h"
#include "gguf.h" #include "gguf.h"
#include "sampling_ext.h" #include "sampling_ext.h"
@@ -148,27 +147,23 @@ func (c *Context) Model() *Model {
} }
func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) { func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
C.llama_kv_self_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta)) C.llama_memory_seq_add(C.llama_get_memory(c.c), C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
} }
func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool { func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
return bool(C.llama_kv_self_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1))) return bool(C.llama_memory_seq_rm(C.llama_get_memory(c.c), C.int(seqId), C.int(p0), C.int(p1)))
} }
func (c *Context) KvCacheSeqCp(srcSeqId int, dstSeqId int, p0 int, p1 int) { func (c *Context) KvCacheSeqCp(srcSeqId int, dstSeqId int, p0 int, p1 int) {
C.llama_kv_self_seq_cp(c.c, C.int(srcSeqId), C.int(dstSeqId), C.int(p0), C.int(p1)) C.llama_memory_seq_cp(C.llama_get_memory(c.c), C.int(srcSeqId), C.int(dstSeqId), C.int(p0), C.int(p1))
} }
func (c *Context) KvCacheClear() { func (c *Context) KvCacheClear() {
C.llama_kv_self_clear(c.c) C.llama_memory_clear(C.llama_get_memory(c.c), true)
}
func (c *Context) KvCacheDefrag() {
C.llama_kv_self_defrag(c.c)
} }
func (c *Context) KvCacheCanShift() bool { func (c *Context) KvCacheCanShift() bool {
return bool(C.llama_kv_self_can_shift(c.c)) return bool(C.llama_memory_can_shift(C.llama_get_memory(c.c)))
} }
// Get the embeddings for a sequence id // Get the embeddings for a sequence id
@@ -460,52 +455,71 @@ func (m *Model) NEmbd() int {
} }
// vision processing // vision processing
type ClipContext struct { type MtmdContext struct {
c *C.struct_clip_ctx c *C.struct_mtmd_context
} }
func NewClipContext(llamaContext *Context, modelPath string) (*ClipContext, error) { func NewMtmdContext(llamaContext *Context, modelPath string) (*MtmdContext, error) {
mp := C.CString(modelPath) mp := C.CString(modelPath)
defer C.free(unsafe.Pointer(mp)) defer C.free(unsafe.Pointer(mp))
c := C.clip_model_load(mp, 1) // TODO: Support non-default params
cp := C.mtmd_context_params_default()
// NOTE: The model and projector embedding lengths are checked during init
c := C.mtmd_init_from_file(mp, C.llama_get_model(llamaContext.c), cp)
if c == nil { if c == nil {
return nil, fmt.Errorf("unable to load clip model: %v", modelPath) return nil, fmt.Errorf("unable to load mtmd model: %v", modelPath)
} }
projEmbedSize := int(C.clip_n_mmproj_embd(c)) return &MtmdContext{c: c}, nil
modelEmbedSize := llamaContext.Model().NEmbd()
if projEmbedSize != modelEmbedSize {
return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
}
return &ClipContext{c: c}, nil
} }
func (c *ClipContext) Free() { func (c *MtmdContext) Free() {
C.clip_free(c.c) C.mtmd_free(c.c)
} }
func (c *ClipContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32, error) { func (c *MtmdContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32, error) {
l := C.llava_image_embed_make_with_bytes(c.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data))) // Initialize the input chunks pointer
if l == nil { ic := C.mtmd_input_chunks_init()
return nil, errors.New("unable to make llava embedding from image") defer C.mtmd_input_chunks_free(ic)
// Initialize an empty text prompt so we can tokenize
it := C.mtmd_input_text_init(C.mtmd_default_marker(), true, true)
defer C.mtmd_input_text_free(it)
// Initialize a bitmap with the image data
bm := C.mtmd_bitmap_init(C.uint32_t(len(data)/3), C.uint32_t(1), (*C.uchar)(unsafe.Pointer(&data[0])))
defer C.mtmd_bitmap_free(bm)
// Tokenize the image
if C.int32_t(0) != C.mtmd_tokenize(c.c, ic, it, &bm, 1) {
return nil, errors.New("unable to tokenize mtmd embedding from image")
}
nChunks := C.mtmd_input_chunks_size(ic)
if nChunks != 1 {
return nil, errors.New("image-only mtmd input tokenized to multiple chunks")
}
chunk := C.mtmd_input_chunks_get(ic, 0)
// Encode the chunk
if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) {
return nil, errors.New("unable to encode mtmd image chunk")
} }
numTokens := int(l.n_image_pos) // Get the embedding
embd := C.mtmd_get_output_embd(c.c)
// Copy embeddings over to go slice
numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk))
numEmbed := llamaContext.Model().NEmbd() numEmbed := llamaContext.Model().NEmbd()
s := unsafe.Slice((*float32)(embd), numEmbed*numTokens)
s := unsafe.Slice((*float32)(l.embed), numEmbed*numTokens)
embed := make([][]float32, numTokens) embed := make([][]float32, numTokens)
rows := make([]float32, len(s)) rows := make([]float32, len(s))
copy(rows, s) copy(rows, s)
for i := range embed { for i := range embed {
embed[i] = rows[i*numEmbed : (i+1)*numEmbed] embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
} }
C.llava_image_embed_free(l)
return embed, nil return embed, nil
} }

View File

@@ -17,7 +17,7 @@ type ImageContext struct {
// mu is required to be held when generating embeddings or accessing the cache // mu is required to be held when generating embeddings or accessing the cache
mu sync.Mutex mu sync.Mutex
clip *llama.ClipContext mtmd *llama.MtmdContext
// cache of images to embeddings // cache of images to embeddings
images []imageCache images []imageCache
@@ -32,7 +32,7 @@ func NewImageContext(llamaContext *llama.Context, modelPath string) (*ImageConte
var c ImageContext var c ImageContext
if arch == "clip" { if arch == "clip" {
c.clip, err = llama.NewClipContext(llamaContext, modelPath) c.mtmd, err = llama.NewMtmdContext(llamaContext, modelPath)
} else { } else {
return nil, fmt.Errorf("unknown vision model architecture: %s", arch) return nil, fmt.Errorf("unknown vision model architecture: %s", arch)
} }
@@ -51,8 +51,8 @@ func (c *ImageContext) Free(modelPath string) {
return return
} }
if c.clip != nil { if c.mtmd != nil {
c.clip.Free() c.mtmd.Free()
} }
} }
@@ -72,8 +72,8 @@ func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte) ([][]f
embed, err := c.findImage(hash) embed, err := c.findImage(hash)
if err != nil { if err != nil {
if c.clip != nil { if c.mtmd != nil {
embed, err = c.clip.NewEmbed(llamaContext, data) embed, err = c.mtmd.NewEmbed(llamaContext, data)
if err != nil { if err != nil {
return nil, err return nil, err
} }