fix conversion

This commit is contained in:
Patrick Devine 2025-03-07 14:06:10 -08:00 committed by Michael Yang
parent 0df1800436
commit c62861f4fa
3 changed files with 57 additions and 42 deletions

View File

@ -190,8 +190,8 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
conv = &gemmaModel{}
case "Gemma2ForCausalLM":
conv = &gemma2Model{}
case "Gemma3ForConditionalGeneration":
conv = &gemma3Model{}
case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration":
conv = &gemma3Model{Architecture: p.Architectures[0]}
case "Phi3ForCausalLM":
conv = &phi3Model{}
case "Qwen2ForCausalLM":
@ -226,6 +226,9 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
}
switch {
case vocabSize == 0:
slog.Warn("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
vocabSize = len(t.Vocabulary.Tokens)
case vocabSize > len(t.Vocabulary.Tokens):
slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
for i := range vocabSize - len(t.Vocabulary.Tokens) {

View File

@ -4,7 +4,13 @@ import "github.com/ollama/ollama/fs/ggml"
type gemma3Model struct {
gemmaModel
TextModel gemma3TextModel `json:"text_config"`
Architecture string
TextModel struct {
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
SlidingWindow uint32 `json:"sliding_window"`
} `json:"text_config"`
VisionModel struct {
NumAttentionHeads uint32 `json:"num_attention_heads"` // attention.head_count 16
LayerNormEpsilon float32 `json:"layer_norm_eps"` // attention.layer_norm_epsilon 1e-05
@ -15,49 +21,54 @@ type gemma3Model struct {
NumChannels uint32 `json:"num_channels"` // num_channels 3
PatchSize uint32 `json:"patch_size"` // patch_size 14
} `json:"vision_config"`
}
type gemma3TextModel struct {
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
SlidingWindow uint32 `json:"sliding_window"`
AttentionLogitSoftcap float32 `json:"attn_logit_softcapping"`
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
RopeLocalTheta float32 `json:"rope_local_base_freq"`
RopeGlobalTheta float32 `json:"rope_global_base_freq"`
SlidingWindow uint32 `json:"sliding_window"`
}
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3"
kv["gemma3.context_length"] = p.TextModel.MaxPositionEmbeddings
kv["gemma3.embedding_length"] = p.TextModel.HiddenSize
kv["gemma3.block_count"] = p.TextModel.HiddenLayers
kv["gemma3.text.feed_forward_length"] = p.TextModel.IntermediateSize
kv["gemma3.attention.head_count"] = p.TextModel.NumAttentionHeads
kv["gemma3.attention.head_count_kv"] = p.TextModel.NumKeyValueHeads
kv["gemma3.text.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
kv["gemma3.attention.key_length"] = p.TextModel.HeadDim
kv["gemma3.attention.value_length"] = p.TextModel.HeadDim
kv["gemma3.text.attention.sliding_window"] = p.TextModel.SlidingWindow
kv["gemma3.text.final_logit_softcapping"] = p.TextModel.FinalLogitSoftcap
kv["gemma3.text.rope.local.freq_base"] = p.TextModel.RopeLocalTheta
kv["gemma3.text.rope.global.freq_base"] = p.TextModel.RopeGlobalTheta
kv["gemma3.vision.block_count"] = p.VisionModel.NumHiddenLayers
kv["gemma3.vision.embedding_length"] = p.VisionModel.HiddenSize
kv["gemma3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
kv["gemma3.vision.image_size"] = p.VisionModel.ImageSize
kv["gemma3.vision.patch_size"] = p.VisionModel.PatchSize
kv["gemma3.vision.num_channels"] = p.VisionModel.NumChannels
kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["gemma3.vision.attention.layer_norm_epsilon"] = p.VisionModel.LayerNormEpsilon
switch p.Architecture {
case "Gemma3ForCausalLM":
kv["gemma3.context_length"] = p.MaxPositionEmbeddings
kv["gemma3.attention.head_count"] = p.NumAttentionHeads
kv["gemma3.attention.head_count_kv"] = p.NumKeyValueHeads
kv["gemma3.text.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["gemma3.attention.key_length"] = p.HeadDim
kv["gemma3.attention.value_length"] = p.HeadDim
kv["gemma3.text.attention.sliding_window"] = p.SlidingWindow
kv["gemma3.text.final_logit_softcapping"] = p.FinalLogitSoftcap
kv["gemma3.text.rope.local.freq_base"] = p.RopeLocalTheta
kv["gemma3.text.rope.global.freq_base"] = p.RopeGlobalTheta
kv["gemma3.embedding_length"] = p.HiddenSize
kv["gemma3.block_count"] = p.HiddenLayers
kv["gemma3.text.feed_forward_length"] = p.IntermediateSize
default:
kv["gemma3.embedding_length"] = p.TextModel.HiddenSize
kv["gemma3.block_count"] = p.TextModel.HiddenLayers
kv["gemma3.text.feed_forward_length"] = p.TextModel.IntermediateSize
kv["gemma3.text.attention.sliding_window"] = p.TextModel.SlidingWindow
kv["gemma3.vision.block_count"] = p.VisionModel.NumHiddenLayers
kv["gemma3.vision.embedding_length"] = p.VisionModel.HiddenSize
kv["gemma3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
kv["gemma3.vision.image_size"] = p.VisionModel.ImageSize
kv["gemma3.vision.patch_size"] = p.VisionModel.PatchSize
kv["gemma3.vision.num_channels"] = p.VisionModel.NumChannels
kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["gemma3.vision.attention.layer_norm_epsilon"] = p.VisionModel.LayerNormEpsilon
}
kv["tokenizer.ggml.bos_token_id"] = uint32(2)
kv["tokenizer.ggml.eot_token_id"] = uint32(1)
return kv
}

View File

@ -32,7 +32,8 @@ type TextModel struct {
}
const (
gemma27BLayerCount = 46
gemmaGlobalCacheCount = 6
gemma27BLayerCount = 46
)
const (
@ -55,15 +56,15 @@ func newTextModel(c ml.Config) *TextModel {
Layers: make([]TextLayer, c.Uint("block_count")),
TextOptions: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length")),
attnValLen: int(c.Uint("attention.value_length")),
eps: c.Float("text.attention.layer_norm_rms_epsilon"),
numHeads: int(c.Uint("attention.head_count", 8)),
numKVHeads: int(c.Uint("attention.head_count_kv", 4)),
attnKeyLen: int(c.Uint("attention.key_length", 256)),
attnValLen: int(c.Uint("attention.value_length", 256)),
eps: c.Float("text.attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalBase: c.Float("text.rope.local.freq_base", 10000.0),
ropeGlobalBase: c.Float("text.rope.global.freq_base", 1000000.0),
ropeScale: c.Float("text.rope.freq_scale", 1.0),
finalLogitSoftcap: c.Float("text.final_logit_softcapping"),
finalLogitSoftcap: c.Float("text.final_logit_softcapping", 30.0),
},
}
@ -84,7 +85,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
ropeType := uint32(2)
ropeBase := opts.ropeLocalBase
if (layer+1)%6 == 0 {
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = opts.ropeGlobalBase
}
@ -116,7 +117,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeBase := m.TextOptions.ropeLocalBase
if (layer+1)%6 == 0 {
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = m.TextOptions.ropeGlobalBase
}
@ -184,7 +185,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
// gemma alternates between the sliding window (local) and causal (global)
// kv cache every 6 layers
cacheType := cacheTypeSWA
if (i+1)%6 == 0 {
if (i+1)%gemmaGlobalCacheCount == 0 {
cacheType = cacheTypeCausal
}
cache.SetLayer(i)