This commit is contained in:
Bruce MacDonald 2025-01-24 16:51:19 -08:00
parent 0d22c0ec1a
commit 60f0b7db76
8 changed files with 817 additions and 11 deletions

.gitignore vendored

@ -12,4 +12,6 @@ test_data
*.crt
llama/build
__debug_bin*
llama/vendor
model/model_test/testdata/*/
!model/model_test/testdata/*.*


@ -24,6 +24,15 @@ type Backend interface {
NewContext() Context
}
type GraphLayer struct {
Name string `json:"name"`
Shape []int64 `json:"shape"`
}
type Graph struct {
Graph []GraphLayer `json:"graph"`
}
var backends = make(map[string]func(*os.File) (Backend, error))
func RegisterBackend(name string, f func(*os.File) (Backend, error)) {
@ -50,6 +59,10 @@ type Context interface {
Forward(Tensor)
Compute(Tensor) Tensor
Close() error
SetDebug(bool)
Trace(string, Tensor)
GetTrace() Graph
}
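The tracing hooks added to `Context` above are exercised by the new model test later in this commit; a minimal sketch of how they fit together (`dumpTrace` is a hypothetical helper, not part of this commit):

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/ml"
)

// dumpTrace enables tracing on a fresh context, runs fn (for example a
// model's forward pass, which calls ctx.Trace for each layer), then
// prints every recorded layer name and shape.
func dumpTrace(b ml.Backend, fn func(ml.Context)) {
	ctx := b.NewContext()
	defer ctx.Close()
	ctx.SetDebug(true)

	fn(ctx)

	for _, layer := range ctx.GetTrace().Graph {
		fmt.Printf("%-48s %v\n", layer.Name, layer.Shape)
	}
}
```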
type Tensor interface {


@ -222,6 +222,7 @@ func (b *Backend) NewContext() ml.Context {
C.size_t(nodes),
true,
),
traceGraph: ml.Graph{},
}
}
@ -232,6 +233,9 @@ type Context struct {
sched *C.struct_ggml_backend_sched
graph *C.struct_ggml_cgraph
nodes int
debug bool
traceGraph ml.Graph
}
func (c *Context) Forward(t ml.Tensor) {
@ -320,6 +324,34 @@ func (c *Context) Close() error {
return nil
}
func (c *Context) SetDebug(debug bool) {
c.debug = debug
}
func (c *Context) Trace(name string, t ml.Tensor) {
if !c.debug {
return
}
// Copy the shape into a fixed four-element array; ggml tensors
// have at most four dimensions.
shape := t.Shape()
shapeArr := make([]int64, 4)
for i := 0; i < len(shape); i++ {
shapeArr[i] = shape[i]
}
c.traceGraph.Graph = append(
c.traceGraph.Graph,
ml.GraphLayer{
Name: name,
Shape: shapeArr,
},
)
}
func (c *Context) GetTrace() ml.Graph {
return c.traceGraph
}
type Tensor struct {
t *C.struct_ggml_tensor
data []byte
@ -555,16 +587,19 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
return &Tensor{
t: C.ggml_rope_ext(
ctx.(*Context).ctx, t.t, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
C.int(ropeDim),
131072, // YaRN n_ctx_train
ropeTypeNorm, // ROPE_TYPE_NORM
C.float(ropeBase),
C.float(ropeScale),
0., // YaRN ext_factor
1., // YaRN attn_factor
32., // YaRN beta_fast
1., // YaRN beta_slow
ctx.(*Context).ctx,
t.t, // a tensor
positionIDs.(*Tensor).t, // b tensor with dims [512, 1, 1, 1]
nil, // c tensor (not shown in log)
C.int(64), // n_dims: 64
2, // mode: 2 (ropeTypeNeox = 2)
C.int(32768), // n_ctx_orig: 32768
C.float(1000000.0), // freq_base: 1000000.000000
C.float(1.0), // freq_scale: 1.000000
C.float(0.0), // ext_factor: 0.000000
C.float(1.0), // attn_factor: 1.000000
C.float(32.0), // beta_fast: 32.000000
C.float(1.0), // beta_slow: 1.000000
),
}
}

model/README.md Normal file

@ -0,0 +1,169 @@
# Ollama Models
!! This is a work-in-progress document !!
## Architecture
```mermaid
graph TB
subgraph Models["Model Layer: LLM Implementations"]
direction TB
llama["llama/model.go"]
mllama["mllama/model.go"]
qwen["qwen2/model.go"]
qwen_vl["qwen2vl/model.go"]
pixtral["pixtral/"]
note1["Each model implements a specific architecture
- Defines model parameters
- Handles tokenization
- Implements forward pass
- Manages model weights"]
end
subgraph ML_Ops["Neural Network Operations"]
direction TB
nn_ops["nn/
linear.go - Matrix operations
embedding.go - Token embeddings
normalization.go - Layer normalization
convolution.go - Conv operations"]
backend["ml/backend.go
Hardware Abstraction Layer
- Defines tensor operations
- Manages computation graphs
- Handles memory allocation"]
note2["Common neural net operations
used across different models
- Abstracts hardware details
- Provides unified API
- Manages computation flow"]
end
subgraph GGML["Hardware Execution Layer"]
direction TB
ggml["ggml.go
CGO Interface
- Bridges Go and C++
- Handles type conversion
- Manages memory between languages"]
subgraph Hardware_Specific["Hardware-Specific Implementations"]
direction LR
cpu["ggml-cpu.h
CPU optimized ops"]
cuda["ggml-cuda.h
NVIDIA GPU ops"]
metal["ggml-metal.h
Apple GPU ops"]
vulkan["ggml-vulkan.h
Cross-platform GPU"]
opencl["ggml-opencl.h
OpenCL acceleration"]
end
note3["GGML provides optimized
implementations for each hardware:
- Automatic dispatch
- Hardware-specific optimizations
- Memory management
- Parallel execution"]
end
%% Connections with explanations
Models --> |"Makes high-level calls
(e.g., self-attention)"| ML_Ops
ML_Ops --> |"Translates to tensor operations
(e.g., matmul, softmax)"| GGML
GGML --> |"Executes optimized code
on target hardware"| Hardware_Specific
%% Styling
classDef model fill:#fff,stroke:#01579b,stroke-width:2px
classDef ml fill:#fff,stroke:#e65100,stroke-width:2px
classDef hw fill:#fff,stroke:#b71c1c,stroke-width:2px
classDef note fill:#fff,stroke:#666,stroke-dasharray: 5 5
class llama,mllama,qwen,qwen_vl,pixtral model
class nn_ops,backend ml
class ggml,cpu,cuda,metal,vulkan,opencl hw
class note1,note2,note3 note
%% Style subgraphs
style Models fill:#fff,stroke:#01579b,stroke-width:2px
style ML_Ops fill:#fff,stroke:#e65100,stroke-width:2px
style GGML fill:#fff,stroke:#b71c1c,stroke-width:2px
style Hardware_Specific fill:#fff,stroke:#b71c1c,stroke-width:1px
```
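To make the layering concrete: a model-level call like `sa.Query.Forward(ctx, hiddenState)` is translated by the nn layer into tensor operations, which GGML executes on the target hardware. A rough sketch of that middle layer (assuming `Linear` wraps a weight and an optional bias, as in `nn/linear.go`):

```go
package nn

import "github.com/ollama/ollama/ml"

// Linear wraps weights loaded from the model file via `gguf` tags.
type Linear struct {
	Weight ml.Tensor `gguf:"weight"`
	Bias   ml.Tensor `gguf:"bias"`
}

// Forward lowers the model-level call to a matrix multiply; Mulmat
// crosses the CGO boundary into GGML, which dispatches to the
// hardware-specific kernel (CPU, CUDA, Metal, ...).
func (l *Linear) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
	t = l.Weight.Mulmat(ctx, t)
	if l.Bias != nil {
		t = t.Add(ctx, l.Bias)
	}
	return t
}
```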
## Adding support for a new model to Ollama
1. Clone the Ollama repo and get it running locally: https://github.com/ollama/ollama/blob/main/docs/development.md
2. Get the original model (research code) running locally. This will almost always be a Python repository.
3. Get a dump of the model graph (tensor names and shapes) from the PyTorch or safetensors weights. Use this snippet to do so:
```python
import torch
import json
import sys

from safetensors.torch import load_file

def extract_graph(model_path):
    # Load weights from either a safetensors file or a PyTorch checkpoint.
    if model_path.endswith('.safetensors'):
        state_dict = load_file(model_path)
    else:
        state_dict = torch.load(model_path, weights_only=True)

    # Record each tensor's name and shape, sorted by name for a stable
    # ordering that is easy to diff.
    graph = [
        {"name": name, "shape": list(tensor.shape)}
        for name, tensor in sorted(state_dict.items())
        if isinstance(tensor, torch.Tensor)
    ]

    # Print one layer per line, matching the layout of the testdata JSON.
    print("{")
    print('  "graph": [')
    for i, layer in enumerate(graph):
        comma = "," if i < len(graph) - 1 else ""
        print(f'    {json.dumps(layer)}{comma}')
    print("  ]")
    print("}")

if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python extract.py <path/to/model>")
        sys.exit(1)
    extract_graph(sys.argv[1])
```
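Save the output next to the tests so it can be compared against the forward-pass trace, for example (hypothetical paths): `python extract.py /path/to/model.safetensors > model/model_test/testdata/<arch>.json`.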
4. Look at a previous model implementation pull request and copy the structure of the files needed. We will need:
   1. A `model/<model-name>` directory
   2. A `model/<model-name>/model.go` file to implement the architecture and forward pass (a skeleton follows this list).
   3. A `model/<model-name>/convert.go` file to implement the conversion from PyTorch/safetensors to GGML.
   4. `model/<model-name>/model_test.go` and `model/<model-name>/convert_test.go` files for testing.
   5. A registration hook so the new model is reachable at runtime (this commit does it with a blank import of the model package in the runner).
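A minimal `model.go` skeleton, mirroring the qwen2 implementation added later in this commit (the package name is a placeholder, and the tagged weight fields and tokenizer still need to be filled in before it works as a full `model.Model`):

```go
package newmodel // hypothetical architecture name

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model"
)

type Model struct {
	model.Base
	// Weight fields are resolved from the converted GGUF via `gguf:"..."` tags.
}

// New reads hyperparameters and tokenizer data from the converted model.
func New(c ml.Config) (model.Model, error) {
	return &Model{}, nil
}

// Forward builds the computation graph for one inference step.
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
	return nil, nil
}

func init() {
	// The registered name must match the architecture of the converted model.
	model.Register("newmodel", New)
}
```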
5. Open a draft pull request in the `ollama/ollama` repo, as a place to ask questions and get answers from Ollama maintainers.
6. Implement conversion from the model weights (PyTorch, safetensors) to GGML in the `model/<your-model>/convert.go` file. Reference other `convert.go` files.
7. Create a Modelfile that only references the PyTorch/safetensors directory. We will handle the other fields later.
Modelfile:
```
FROM /path/to/model
```
Use `ollama create` to convert the model:
`go run . create <my-model> -f /path/to/Modelfile`
8. Implement the `New()` and `Forward()` logic in `model/<your-model>/model.go`. Reference other `model.go` files.
Run the model and get the debug output of the forward pass to compare with the output of the research implementation from step 2:
`OLLAMA_DEBUG=1 go run . run <my-model>`
9. Implement a new tokenizer, if needed.
10. Test text generation; this step requires knowing the prompt format:
`go run . run <my-model> "hello"`
11. Add tests to `model/<your-model>/model_test.go` and `model/<your-model>/convert_test.go`.
12. Push changes to the `ollama/ollama` pull request, and move the pull request out of the draft state.
13. Push the model to ollama.com:
   1. Find the model's prompt format and convert it to a Go template (see the example after this list).
   2. Create a Modelfile `FROM` the converted GGUF, and add the `TEMPLATE`, `LICENSE`, and parameters as needed.
   3. `ollama create <your-namespace>/<your-model> -f /path/to/Modelfile`
   4. `ollama push <your-namespace>/<your-model>`
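For example, qwen2 uses the ChatML prompt format, which translates to a Go template along these lines (a sketch; check the model's chat template for the exact tokens):
```
TEMPLATE """{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}<|im_start|>user
{{ .Prompt }}<|im_end|>
<|im_start|>assistant
"""
```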
14. Run end-to-end integration tests.


@ -0,0 +1,91 @@
package modeltest
import (
"encoding/json"
"os"
"path/filepath"
"reflect"
"testing"
"github.com/ollama/ollama/cache"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
_ "github.com/ollama/ollama/model/qwen2"
)
func TestForward(t *testing.T) {
cases := []string{
"qwen2",
// Add more model architectures here...
}
for _, tt := range cases {
t.Run(tt, func(t *testing.T) {
t.Parallel()
p := filepath.Join("testdata", tt)
if testing.Short() {
t.Skip("skipping in short mode")
} else if _, err := os.Stat(p); err != nil {
t.Skipf("%s not found", p)
}
f, err := os.CreateTemp(t.TempDir(), "f16")
if err != nil {
t.Fatal(err)
}
defer func() {
f.Close()
os.Remove(f.Name())
}()
if err := convert.ConvertModel(os.DirFS(p), f); err != nil {
t.Fatal(err)
}
m, err := model.New(f.Name())
if err != nil {
t.Fatal(err)
}
b := m.Backend()
ctx := b.NewContext()
ctx.SetDebug(true)
// Run forward pass
_, err = model.Forward(ctx, m, model.WithCache(cache.NewCausalCache(m.Backend(), 2048, ml.DTypeF32)))
if err != nil {
t.Fatal(err)
}
// Validate the graph layers
data, err := os.ReadFile(filepath.Join("testdata", tt+".json"))
if err != nil {
t.Fatal(err)
}
var expected ml.Graph
if err := json.Unmarshal(data, &expected); err != nil {
t.Fatal(err)
}
result := ctx.GetTrace()
if len(result.Graph) != len(expected.Graph) {
t.Errorf("expected %d layers, got %d", len(expected.Graph), len(result.Graph))
}
for i, layer := range expected.Graph {
if i >= len(result.Graph) {
break
}
actual := result.Graph[i]
if layer.Name != actual.Name {
t.Errorf("layer %d: expected name %s, got %s", i, layer.Name, actual.Name)
}
if !reflect.DeepEqual(layer.Shape, actual.Shape) {
t.Errorf("layer %d: expected shape %v, got %v", i, layer.Shape, actual.Shape)
}
}
})
}
}
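Note: this test skips itself unless the original weights are present under `model/model_test/testdata/<arch>/` (ignored by the `.gitignore` change above, while the expected `<arch>.json` files stay tracked). With the weights and expected JSON in place, run it with `go test ./model/model_test/`.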

model/model_test/testdata/qwen2.json vendored Normal file

@ -0,0 +1,294 @@
{
"graph": [
{"name": "model.embed_tokens.weight", "shape": [151936, 896]},
{"name": "model.layers.0.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.0.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.0.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.0.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.0.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.0.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.0.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.0.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.0.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.0.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.0.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.0.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.1.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.1.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.1.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.1.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.1.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.1.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.1.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.1.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.1.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.1.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.1.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.1.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.10.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.10.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.10.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.10.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.10.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.10.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.10.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.10.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.10.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.10.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.10.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.10.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.11.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.11.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.11.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.11.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.11.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.11.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.11.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.11.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.11.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.11.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.11.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.11.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.12.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.12.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.12.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.12.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.12.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.12.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.12.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.12.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.12.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.12.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.12.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.12.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.13.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.13.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.13.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.13.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.13.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.13.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.13.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.13.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.13.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.13.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.13.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.13.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.14.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.14.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.14.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.14.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.14.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.14.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.14.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.14.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.14.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.14.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.14.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.14.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.15.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.15.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.15.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.15.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.15.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.15.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.15.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.15.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.15.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.15.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.15.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.15.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.16.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.16.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.16.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.16.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.16.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.16.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.16.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.16.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.16.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.16.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.16.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.16.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.17.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.17.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.17.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.17.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.17.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.17.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.17.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.17.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.17.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.17.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.17.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.17.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.18.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.18.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.18.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.18.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.18.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.18.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.18.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.18.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.18.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.18.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.18.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.18.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.19.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.19.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.19.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.19.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.19.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.19.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.19.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.19.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.19.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.19.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.19.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.19.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.2.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.2.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.2.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.2.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.2.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.2.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.2.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.2.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.2.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.2.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.2.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.2.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.20.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.20.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.20.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.20.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.20.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.20.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.20.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.20.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.20.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.20.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.20.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.20.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.21.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.21.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.21.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.21.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.21.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.21.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.21.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.21.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.21.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.21.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.21.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.21.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.22.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.22.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.22.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.22.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.22.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.22.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.22.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.22.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.22.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.22.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.22.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.22.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.23.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.23.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.23.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.23.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.23.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.23.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.23.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.23.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.23.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.23.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.23.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.23.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.3.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.3.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.3.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.3.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.3.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.3.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.3.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.3.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.3.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.3.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.3.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.3.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.4.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.4.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.4.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.4.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.4.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.4.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.4.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.4.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.4.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.4.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.4.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.4.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.5.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.5.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.5.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.5.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.5.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.5.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.5.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.5.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.5.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.5.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.5.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.5.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.6.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.6.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.6.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.6.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.6.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.6.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.6.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.6.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.6.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.6.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.6.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.6.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.7.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.7.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.7.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.7.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.7.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.7.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.7.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.7.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.7.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.7.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.7.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.7.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.8.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.8.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.8.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.8.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.8.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.8.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.8.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.8.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.8.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.8.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.8.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.8.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.layers.9.input_layernorm.weight", "shape": [896]},
{"name": "model.layers.9.mlp.down_proj.weight", "shape": [896, 4864]},
{"name": "model.layers.9.mlp.gate_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.9.mlp.up_proj.weight", "shape": [4864, 896]},
{"name": "model.layers.9.post_attention_layernorm.weight", "shape": [896]},
{"name": "model.layers.9.self_attn.k_proj.bias", "shape": [128]},
{"name": "model.layers.9.self_attn.k_proj.weight", "shape": [128, 896]},
{"name": "model.layers.9.self_attn.o_proj.weight", "shape": [896, 896]},
{"name": "model.layers.9.self_attn.q_proj.bias", "shape": [896]},
{"name": "model.layers.9.self_attn.q_proj.weight", "shape": [896, 896]},
{"name": "model.layers.9.self_attn.v_proj.bias", "shape": [128]},
{"name": "model.layers.9.self_attn.v_proj.weight", "shape": [128, 896]},
{"name": "model.norm.weight", "shape": [896]}
]
}

model/qwen2/model.go Normal file

@ -0,0 +1,201 @@
package qwen2
import (
"fmt"
"log/slog"
"math"
"github.com/ollama/ollama/cache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
)
type Options struct {
hiddenSize, numHeads, numKVHeads int64
eps, ropeBase, ropeScale float32
ropeDim uint32
}
type Model struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*Options
}
func New(c ml.Config) (model.Model, error) {
m := &Model{
BytePairEncoding: model.BytePairEncoding{
Pretokenizer: c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
Vocabulary: &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: c.Uint("tokenizer.ggml.bos_token_id"),
EOS: c.Uint("tokenizer.ggml.eos_token_id"),
},
},
Layers: make([]Layer, c.Uint("block_count")),
Options: &Options{
hiddenSize: int64(c.Uint("embedding_length")),
numHeads: int64(c.Uint("attention.head_count")),
numKVHeads: int64(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count", 64),
},
}
slog.Debug("model configuration",
"arch", "qwen2",
"vocab_size", len(c.Strings("tokenizer.ggml.tokens")),
"n_merges", len(c.Strings("tokenizer.ggml.merges")),
"n_ctx_train", c.Uint("context_length"),
"n_embd", m.hiddenSize,
"n_layer", len(m.Layers),
"n_head", m.numHeads,
"n_head_kv", m.numKVHeads,
"n_rot", m.ropeDim,
"f_norm_rms_eps", m.eps,
"rope_freq_base", m.ropeBase,
"rope_freq_scale", m.ropeScale,
"bos_token_id", c.Uint("tokenizer.ggml.bos_token_id"),
"eos_token_id", c.Uint("tokenizer.ggml.eos_token_id"),
)
return m, nil
}
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, inputPositions ml.Tensor, layerIdx int, cache cache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
q := sa.Query.Forward(ctx, hiddenState)
ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.q_proj", layerIdx), q)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, inputPositions, nil, opts.ropeDim, opts.ropeBase, opts.ropeScale)
ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.q_proj.rope", layerIdx), q)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, inputPositions, nil, opts.ropeDim, opts.ropeBase, opts.ropeScale)
ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.k_proj.rope", layerIdx), k)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.v_proj", layerIdx), v)
k, v, mask := cache.Put(ctx, k, v)
// Move the head dimension outward so each head is matmul'd independently.
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
// Scaled dot-product attention: softmax(K^T Q / sqrt(headDim) + mask) V.
kq := k.Mulmat(ctx, q)
kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
kq = kq.Add(ctx, mask)
kq = kq.Softmax(ctx)
kqv := v.Mulmat(ctx, kq)
kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
output := sa.Output.Forward(ctx, kqv)
return output
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *SelfAttention
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, layerIdx int, cache cache.Cache, opts *Options) ml.Tensor {
ctx.Trace(fmt.Sprintf("model.layers.%d.input", layerIdx), hiddenState)
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
ctx.Trace(fmt.Sprintf("model.layers.%d.input_layernorm", layerIdx), hiddenState)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, layerIdx, cache, opts)
ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.output", layerIdx), hiddenState)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.residual", layerIdx), hiddenState)
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
ctx.Trace(fmt.Sprintf("model.layers.%d.post_attention_layernorm", layerIdx), hiddenState)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
ctx.Trace(fmt.Sprintf("model.layers.%d.mlp", layerIdx), hiddenState)
output := hiddenState.Add(ctx, residual)
ctx.Trace(fmt.Sprintf("model.layers.%d.output", layerIdx), output)
return output
}
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
slog.Debug("input tokens", "input_ids", opts.Inputs())
inputs, err := ctx.FromIntSlice(opts.Inputs(), len(opts.Inputs()))
if err != nil {
return nil, err
}
positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions()))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
ctx.Trace("model.embed_tokens", hiddenState)
for i, layer := range m.Layers {
hiddenState = layer.Forward(ctx, hiddenState, positions, i, opts.Cache.Sub(i), m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
ctx.Trace("model.norm", hiddenState)
hiddenState = m.Output.Forward(ctx, hiddenState)
ctx.Trace("model.output", hiddenState)
outputs, err := ctx.FromIntSlice(opts.Outputs(), len(opts.Outputs()))
if err != nil {
return nil, err
}
return hiddenState.Rows(ctx, outputs), nil
}
func init() {
model.Register("qwen2", New)
}


@ -32,6 +32,7 @@ import (
_ "github.com/ollama/ollama/model/llama"
_ "github.com/ollama/ollama/model/mllama"
_ "github.com/ollama/ollama/model/qwen2"
)
// input is an element of the prompt to process, either