From 77409b4e9b8e5bdfffecda46f04eddf2875812ba Mon Sep 17 00:00:00 2001
From: Patrick Devine
Date: Wed, 5 Nov 2025 17:17:58 -0800
Subject: [PATCH] convert: add deepseek converter

This change adds the ability for `ollama create` to convert models that use the
DeepSeek2 architecture (specifically DeepSeekV3 and DeepSeek-R1).
---
 convert/convert.go           |   2 +
 convert/convert_deepseek2.go | 199 +++++++++++++++++++++++++++++++++++
 2 files changed, 201 insertions(+)
 create mode 100644 convert/convert_deepseek2.go

diff --git a/convert/convert.go b/convert/convert.go
index 3e98eee1ac..6a54302085 100644
--- a/convert/convert.go
+++ b/convert/convert.go
@@ -206,6 +206,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
 		conv = &commandrModel{}
 	case "GptOssForCausalLM":
 		conv = &gptossModel{}
+	case "DeepseekV3ForCausalLM":
+		conv = &deepseek2Model{}
 	default:
 		return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
 	}
diff --git a/convert/convert_deepseek2.go b/convert/convert_deepseek2.go
new file mode 100644
index 0000000000..aa6203277a
--- /dev/null
+++ b/convert/convert_deepseek2.go
@@ -0,0 +1,199 @@
+package convert
+
+import (
+	"cmp"
+	"fmt"
+	"log/slog"
+	"regexp"
+	"strconv"
+
+	"github.com/ollama/ollama/fs/ggml"
+)
+
+// deepseek2BlockRe matches the leading "blk.<index>" of a tensor name and
+// captures the index. It is compiled once at package scope instead of once
+// per tensor.
+var deepseek2BlockRe = regexp.MustCompile(`^blk\.(\d+)`)
+
+// deepseek2Model holds the config.json fields read when converting a
+// DeepseekV3ForCausalLM checkpoint to the "deepseek2" GGUF architecture.
+type deepseek2Model struct {
+	ModelParameters               // architectures, vocab_size
+	MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
+	HiddenSize            uint32  `json:"hidden_size"`
+	HiddenLayers          uint32  `json:"num_hidden_layers"`
+	IntermediateSize      uint32  `json:"intermediate_size"`
+	NumAttentionHeads     uint32  `json:"num_attention_heads"`
+	NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
+	RMSNormEPS            float32 `json:"rms_norm_eps"`
+
+	RopeTheta     float32 `json:"rope_theta"`
+	QKNopeHeadDim uint32  `json:"qk_nope_head_dim"`
+	QKRopeHeadDim uint32  `json:"qk_rope_head_dim"`
+	KVLoraRank    uint32  `json:"kv_lora_rank"`
+	QLoraRank     uint32  `json:"q_lora_rank"`
+	VHeadDim      uint32  `json:"v_head_dim"`
+
+	ExpertCount            uint32  `json:"n_routed_experts"`
+	ExpertSharedCount      uint32  `json:"n_shared_experts"`
+	ExpertIntermediateSize uint32  `json:"moe_intermediate_size"`
+	ExpertUsedCount        uint32  `json:"num_experts_per_tok"`
+	ExpertWeightsNorm      bool    `json:"norm_topk_prob"`
+	ExpertWeightsScale     float32 `json:"routed_scaling_factor"`
+
+	ScoringFunc            string `json:"scoring_func"`
+	LeadingDenseBlockCount uint32 `json:"first_k_dense_replace"`
+
+	RopeScaling struct {
+		Factor                        float32 `json:"factor"`
+		OriginalMaxPositionEmbeddings uint32  `json:"original_max_position_embeddings"`
+		Type                          string  `json:"type"`
+		MScaleAllDim                  float32 `json:"mscale_all_dim"`
+	} `json:"rope_scaling"`
+
+	Architecture string
+}
+
+// KV returns the GGUF metadata for the model: the keys produced by the
+// embedded ModelParameters plus the deepseek2.* attention, expert and rope
+// settings taken from the parsed config.
+func (p *deepseek2Model) KV(t *Tokenizer) ggml.KV {
+	kv := p.ModelParameters.KV(t)
+	kv["general.architecture"] = "deepseek2"
+	kv["general.type"] = "model"
+	kv["deepseek2.block_count"] = p.HiddenLayers
+
+	numHeads := p.NumAttentionHeads
+	numKVHeads := p.NumKeyValueHeads
+
+	kv["deepseek2.attention.head_count"] = numHeads
+	kv["deepseek2.attention.head_count_kv"] = numKVHeads
+	kv["deepseek2.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim
+	kv["deepseek2.attention.kv_lora_rank"] = p.KVLoraRank
+	kv["deepseek2.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
+	kv["deepseek2.attention.q_lora_rank"] = p.QLoraRank
+	kv["deepseek2.attention.value_length"] = p.VHeadDim
+	kv["deepseek2.context_length"] = p.MaxPositionEmbeddings
+	kv["deepseek2.embedding_length"] = p.HiddenSize
+	kv["deepseek2.expert_count"] = p.ExpertCount
+	kv["deepseek2.expert_feed_forward_length"] = p.ExpertIntermediateSize
+	kv["deepseek2.expert_shared_count"] = p.ExpertSharedCount
+
+	// scoringFunc stays 0 when scoring_func is missing or unrecognized
+	var scoringFunc uint32
+	switch p.ScoringFunc {
+	case "softmax":
+		// not currently supported in the model, but needed for Deepseek-OCR
+		scoringFunc = 1
+	case "sigmoid":
+		scoringFunc = 2
+	default:
+		slog.Warn("unrecognized scoring func, expert gating func left unset", "scoring_func", p.ScoringFunc)
+	}
+	kv["deepseek2.expert_gating_func"] = scoringFunc
+	kv["deepseek2.expert_used_count"] = p.ExpertUsedCount
+	kv["deepseek2.expert_weights_norm"] = p.ExpertWeightsNorm
+	kv["deepseek2.expert_weights_scale"] = p.ExpertWeightsScale
+	kv["deepseek2.feed_forward_length"] = p.IntermediateSize
+	kv["deepseek2.leading_dense_block_count"] = p.LeadingDenseBlockCount
+
+	kv["deepseek2.rope.dimension_count"] = p.QKRopeHeadDim
+	// use a base of 10000 when rope_theta is zero (e.g. absent from the config)
+	kv["deepseek2.rope.freq_base"] = cmp.Or(p.RopeTheta, 10000.0)
+	kv["deepseek2.rope.scaling.factor"] = p.RopeScaling.Factor
+	kv["deepseek2.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings
+	kv["deepseek2.rope.scaling.type"] = p.RopeScaling.Type
+	kv["deepseek2.rope.scaling.yarn_log_multiplier"] = 0.1 * p.RopeScaling.MScaleAllDim
+
+	kv["tokenizer.ggml.pre"] = "deepseek-v3"
+
+	return kv
+}
+
+// Replacements returns old/new substring pairs that rewrite source tensor
+// names into their GGUF names. Longer patterns are listed ahead of the
+// shorter ones they contain (e.g. "mlp.gate.e_score_correction_bias" before
+// "mlp.gate"); presumably the caller applies the pairs in order — confirm.
+func (p *deepseek2Model) Replacements() []string {
+	return []string{
+		"lm_head", "output",
+		"model.embed_tokens", "token_embd",
+		"model.norm", "output_norm",
+		"language_model.", "",
+		"model.layers", "blk",
+		"input_layernorm", "attn_norm",
+		"self_attn.kv_a_proj_with_mqa", "attn_kv_a_mqa",
+		"self_attn.kv_a_layernorm", "attn_kv_a_norm",
+		"self_attn.kv_b_proj", "attn_kv_b",
+		"self_attn.q_a_proj", "attn_q_a",
+		"self_attn.q_a_layernorm", "attn_q_a_norm",
+		"self_attn.q_b_proj", "attn_q_b",
+		"self_attn.o_proj", "attn_output",
+		"post_attention_layernorm", "ffn_norm",
+		"mlp.shared_experts.down_proj", "ffn_down_shexp",
+		"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
+		"mlp.shared_experts.up_proj", "ffn_up_shexp",
+		"mlp.gate_proj", "ffn_gate",
+		"mlp.down_proj", "ffn_down",
+		"mlp.up_proj", "ffn_up",
+		"mlp.gate.e_score_correction_bias", "exp_probs_b.bias",
+		"mlp.gate", "ffn_gate_inp",
+	}
+}
+
+// Tensors builds the output tensor list. For every block it registers merges
+// from blk.N.mlp.experts.*.{gate,up,down}_proj.weight to
+// blk.N.ffn_{gate,up,down}_exps.weight and hands them to mergeTensors. The
+// tensors mergeTensors hands back in s are passed through unchanged, except
+// those in blocks at or beyond HiddenLayers, which are dropped.
+func (p *deepseek2Model) Tensors(s []Tensor) (out []*ggml.Tensor) {
+	merges := make([]merge, p.HiddenLayers*3)
+	for i := range p.HiddenLayers {
+		merges[i*3+0] = merge{
+			fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
+			fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
+		}
+		merges[i*3+1] = merge{
+			fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
+			fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
+		}
+		merges[i*3+2] = merge{
+			fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
+			fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
+		}
+	}
+
+	// skipLayer reports whether n names a tensor in block minValue or later;
+	// names without a "blk.<index>" prefix are never skipped
+	skipLayer := func(n string, minValue uint32) bool {
+		matches := deepseek2BlockRe.FindStringSubmatch(n)
+		if matches == nil {
+			return false
+		}
+
+		// parse directly as 32-bit so an oversized index is rejected instead
+		// of being truncated by a later conversion
+		blkNum, err := strconv.ParseUint(matches[1], 10, 32)
+		if err != nil {
+			return false
+		}
+
+		return uint32(blkNum) >= minValue
+	}
+
+	out, s = mergeTensors(s, merges...)
+	for _, t := range s {
+		// skip any additional layers (such as the Multi-Token Prediction layer)
+		if skipLayer(t.Name(), p.HiddenLayers) {
+			slog.Debug("skipping layer", "name", t.Name())
+			continue
+		}
+		out = append(out, &ggml.Tensor{
+			Name:     t.Name(),
+			Kind:     t.Kind(),
+			Shape:    t.Shape(),
+			WriterTo: t,
+		})
+	}
+	return out
+}