package llama

import (
	"cmp"
	"fmt"
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/fast"
	"github.com/ollama/ollama/ml/nn/rope"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)

// Options holds the hyperparameters read from the model config in New and
// shared by every layer.
//
// headDim and ropeDim are optional in the config and may be zero; use
// attentionDims to obtain the effective values.
type Options struct {
	hiddenSize, numHeads, numKVHeads int
	headDim, ropeDim                 int
	eps, ropeBase, ropeScale         float32
}

// attentionDims returns the effective per-head dimension and the number of
// dimensions RoPE rotates. A zero headDim falls back to hiddenSize/numHeads,
// and a zero ropeDim falls back to the resulting headDim.
//
// Both SelfAttention.Forward and Model.Shift apply RoPE to keys, so they must
// agree on these values; keeping the fallback logic here prevents them from
// drifting apart.
func (o *Options) attentionDims() (headDim, ropeDim int) {
	headDim = cmp.Or(o.headDim, o.hiddenSize/o.numHeads)
	ropeDim = cmp.Or(o.ropeDim, headDim)
	return headDim, ropeDim
}

// Model is the llama-architecture model: a token embedding, a stack of
// Layers, a final RMS norm and an output projection.
type Model struct {
	model.Base
	model.BytePairEncoding

	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []Layer       `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	// NOTE(review): the alt:token_embd tag presumably lets the output
	// projection load the token embedding weights when no separate "output"
	// tensor is present — confirm against the gguf loader.
	Output *nn.Linear `gguf:"output,alt:token_embd"`

	*Options
}

// New builds a llama Model from the config c.
//
// It returns an error when the config's tokenizer model is "llama" (only the
// byte-pair-encoding tokenizer is wired up here) or when general.name is
// "deepseek-ai", a best-effort check for incompatible deepseek-coder models.
// The returned model uses a causal KV cache whose shift callback is m.Shift.
func New(c fs.Config) (model.Model, error) {
	// This model currently only supports the gpt2 tokenizer
	if c.String("tokenizer.ggml.model") == "llama" {
		return nil, fmt.Errorf("unsupported tokenizer: llama")
	}
	// Best effort detection of library/deepseek-coder model(s) which are incompatible
	if c.String("general.name") == "deepseek-ai" {
		return nil, fmt.Errorf("unsupported model: %s", c.String("general.name"))
	}
	m := Model{
		BytePairEncoding: model.NewBytePairEncoding(
			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Types:  c.Ints("tokenizer.ggml.token_type"),
				Merges: c.Strings("tokenizer.ggml.merges"),
				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
				BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
				// The single eos_token_id is listed first, followed by any
				// additional ids from eos_token_ids.
				EOS: append(
					[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
					c.Ints("tokenizer.ggml.eos_token_ids")...,
				),
			},
		),
		Layers: make([]Layer, c.Uint("block_count")),
		Options: &Options{
			hiddenSize: int(c.Uint("embedding_length")),
			numHeads:   int(c.Uint("attention.head_count")),
			numKVHeads: int(c.Uint("attention.head_count_kv")),
			headDim:    int(c.Uint("attention.key_length")),
			ropeDim:    int(c.Uint("rope.dimension_count")),
			eps:        c.Float("attention.layer_norm_rms_epsilon"),
			ropeBase:   c.Float("rope.freq_base"),
			ropeScale:  c.Float("rope.freq_scale", 1),
		},
	}

	m.Cache = kvcache.NewCausalCache(m.Shift)

	return &m, nil
}

// SelfAttention holds one layer's query/key/value/output projections and the
// optional per-layer RoPE frequency factors.
type SelfAttention struct {
	Query       *nn.Linear `gguf:"attn_q"`
	Key         *nn.Linear `gguf:"attn_k"`
	Value       *nn.Linear `gguf:"attn_v"`
	Output      *nn.Linear `gguf:"attn_output"`
	RopeFactors ml.Tensor  `gguf:"rope_freqs.weight"`
}

// Forward computes self-attention over hiddenState.
//
// The projected queries are reshaped to (headDim, numHeads, batchSize), and
// the keys and values to (headDim, numKVHeads, batchSize). RoPE is applied to
// queries and keys at the given positions. Attention is scaled by
// 1/sqrt(headDim) and uses cache. The result is flattened back to
// (headDim*numHeads, batchSize) and passed through the output projection.
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
	batchSize := hiddenState.Dim(1)
	headDim, ropeDim := opts.attentionDims()

	query := sa.Query.Forward(ctx, hiddenState)
	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)

	key := sa.Key.Forward(ctx, hiddenState)
	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

	value := sa.Value.Forward(ctx, hiddenState)
	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

	query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
	key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))

	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
	attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
	return sa.Output.Forward(ctx, attention)
}

// Shift is the callback given to the causal KV cache in New. It applies RoPE
// to the cached key tensor using shift as the positions, with the given
// layer's RoPE factors.
//
// It must rotate the same number of dimensions as SelfAttention.Forward did.
// The previous code fell back to hiddenSize/numHeads and ignored headDim, so
// the two disagreed when rope.dimension_count was absent and
// attention.key_length differed from hiddenSize/numHeads.
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	// Selected through m.Options explicitly so the call cannot become
	// ambiguous with methods promoted from the other embedded fields.
	_, ropeDim := m.Options.attentionDims()
	return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
}

// MLP is the gated feed-forward block of a layer.
type MLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

// Forward computes Down(SILU(Gate(x)) * Up(x)). opts is unused here and is
// presumably kept so the signature matches the other per-layer Forward methods.
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, hiddenState)
}

// Layer is one transformer block: RMS-normed self-attention followed by an
// RMS-normed MLP, each wrapped in a residual connection.
type Layer struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention *SelfAttention
	MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
	MLP           *MLP
}

// Forward runs the block on hiddenState.
//
// When outputs is non-nil (only for the final layer; see Model.Forward), the
// attention output and its residual are reduced to the rows listed in outputs
// before the residual add. The MLP then runs only on those rows.
func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, opts)

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	return hiddenState.Add(ctx, residual)
}

// Forward runs the full model on batch.
//
// It embeds batch.Inputs and runs every layer, selecting each layer's slot in
// the KV cache first. It then applies the final norm and output projection.
// Only the last layer receives batch.Outputs, so only those token positions
// reach the output projection.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))

	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)

	for i, layer := range m.Layers {
		m.Cache.SetLayer(i)

		var outputs ml.Tensor
		if i == len(m.Layers)-1 {
			outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
		}

		hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)
	}

	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	return m.Output.Forward(ctx, hiddenState), nil
}

// init registers New as the constructor for the "llama" architecture.
func init() {
	model.Register("llama", New)
}