From 63e6509ec0456f6dbfc303aead9ed08227af4118 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Thu, 20 Mar 2025 15:37:21 -0700 Subject: [PATCH] vision conversion --- convert/convert_mistral.go | 63 +++++++++++++++++++++----------------- 1 file changed, 35 insertions(+), 28 deletions(-) diff --git a/convert/convert_mistral.go b/convert/convert_mistral.go index ffe6f79da..99032b51c 100644 --- a/convert/convert_mistral.go +++ b/convert/convert_mistral.go @@ -13,10 +13,10 @@ import ( type mistral3Model struct { ModelParameters - // ImageTokenIndex uint32 `json:"image_token_index"` - // SpatialMergeSize uint32 `json:"spatial_merge_size"` - // VisionFeatureLayer int32 `json:"vision_feature_layer"` - TextModel struct { + ImageTokenIndex uint32 `json:"image_token_index"` + SpatialMergeSize uint32 `json:"spatial_merge_size"` + VisionFeatureLayer int32 `json:"vision_feature_layer"` + TextModel struct { NumHiddenLayers uint32 `json:"num_hidden_layers"` MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` HiddenSize uint32 `json:"hidden_size"` @@ -30,20 +30,20 @@ type mistral3Model struct { HiddenAct string `json:"hidden_act"` VocabSize uint32 `json:"vocab_size"` } `json:"text_config"` - // VisionModel struct { - // NumAttentionHeads uint32 `json:"num_attention_heads"` - // NumHiddenLayers uint32 `json:"num_hidden_layers"` - // HiddenSize uint32 `json:"hidden_size"` - // IntermediateSize uint32 `json:"intermediate_size"` - // ImageSize uint32 `json:"image_size"` - // NumChannels uint32 `json:"num_channels"` - // PatchSize uint32 `json:"patch_size"` - // HeadDim uint32 `json:"head_dim"` - // HiddenAct string `json:"hidden_act"` - // RopeTheta float32 `json:"rope_theta"` - // } `json:"vision_config"` - // MultiModalProjectorBias bool `json:"multimodal_projector_bias"` - // ProjectorHiddenAct string `json:"projector_hidden_act"` + VisionModel struct { + NumAttentionHeads uint32 `json:"num_attention_heads"` + NumHiddenLayers uint32 `json:"num_hidden_layers"` + HiddenSize uint32 `json:"hidden_size"` + IntermediateSize uint32 `json:"intermediate_size"` + ImageSize uint32 `json:"image_size"` + NumChannels uint32 `json:"num_channels"` + PatchSize uint32 `json:"patch_size"` + HeadDim uint32 `json:"head_dim"` + HiddenAct string `json:"hidden_act"` + RopeTheta float32 `json:"rope_theta"` + } `json:"vision_config"` + MultiModalProjectorBias bool `json:"multimodal_projector_bias"` + ProjectorHiddenAct string `json:"projector_hidden_act"` } func (p *mistral3Model) KV(t *Tokenizer) ggml.KV { @@ -64,19 +64,26 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV { kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta + // Vision configuration + kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers + kv["mistral3.vision.embedding_length"] = p.VisionModel.HiddenSize + kv["mistral3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize + kv["mistral3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads + kv["mistral3.vision.attention.key_length"] = p.VisionModel.HeadDim + kv["mistral3.vision.image_size"] = p.VisionModel.ImageSize + kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize + kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels + kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta + // Multimodal configuration - // kv["mistral3.image_token_index"] = p.ImageTokenIndex - // kv["mistral3.spatial_merge_size"] = p.SpatialMergeSize + kv["mistral3.image_token_index"] = p.ImageTokenIndex + kv["mistral3.spatial_merge_size"] = p.SpatialMergeSize - // if p.VisionFeatureLayer != 0 { - // kv["mistral3.vision_feature_layer"] = p.VisionFeatureLayer - // } + kv["mistral3.mm.projector_bias"] = p.MultiModalProjectorBias - // kv["mistral3.mm.projector_bias"] = p.MultiModalProjectorBias - - // if p.ProjectorHiddenAct != "" { - // kv["mistral3.mm.projector_hidden_act"] = p.ProjectorHiddenAct - // } + if p.ProjectorHiddenAct != "" { + kv["mistral3.mm.projector_hidden_act"] = p.ProjectorHiddenAct + } return kv }