mirror of https://github.com/ollama/ollama.git (synced 2025-04-07 19:38:08 +02:00)

working

This commit is contained in: parent 0d22c0ec1a, commit 60f0b7db76

4  .gitignore (vendored)
@@ -12,4 +12,6 @@ test_data
 *.crt
 llama/build
 __debug_bin*
 llama/vendor
+model/model_test/testdata/*/
+!model/model_test/testdata/*.*

@@ -24,6 +24,15 @@ type Backend interface {
 	NewContext() Context
 }
 
+type GraphLayer struct {
+	Name  string  `json:"name"`
+	Shape []int64 `json:"shape"`
+}
+
+type Graph struct {
+	Graph []GraphLayer `json:"graph"`
+}
+
 var backends = make(map[string]func(*os.File) (Backend, error))
 
 func RegisterBackend(name string, f func(*os.File) (Backend, error)) {
@@ -50,6 +59,10 @@ type Context interface {
 	Forward(Tensor)
 	Compute(Tensor) Tensor
 	Close() error
+
+	SetDebug(bool)
+	Trace(string, Tensor)
+	GetTrace() Graph
 }
 
 type Tensor interface {
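The additions above give every `Context` an opt-in trace of named tensors. A minimal sketch of how a caller might use it, assuming only the interfaces in this hunk (`dumpTrace` and the package name are hypothetical, not part of this commit):

```go
package tracedump

import (
	"encoding/json"
	"os"

	"github.com/ollama/ollama/ml"
)

// dumpTrace is a hypothetical helper: after ctx.SetDebug(true) and a
// forward pass, ctx.GetTrace() returns the ml.Graph accumulated by the
// Trace() calls, which marshals to the same JSON shape as the testdata
// fixtures added later in this commit.
func dumpTrace(ctx ml.Context, path string) error {
	g := ctx.GetTrace()
	data, err := json.MarshalIndent(g, "", "  ")
	if err != nil {
		return err
	}
	return os.WriteFile(path, data, 0o644)
}
```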
@@ -222,6 +222,7 @@ func (b *Backend) NewContext() ml.Context {
 			C.size_t(nodes),
 			true,
 		),
+		traceGraph: ml.Graph{},
 	}
 }
 
@@ -232,6 +233,9 @@ type Context struct {
 	sched *C.struct_ggml_backend_sched
 	graph *C.struct_ggml_cgraph
 	nodes int
+
+	debug      bool
+	traceGraph ml.Graph
 }
 
 func (c *Context) Forward(t ml.Tensor) {
@@ -320,6 +324,34 @@ func (c *Context) Close() error {
 	return nil
 }
 
+func (c *Context) SetDebug(debug bool) {
+	c.debug = debug
+}
+
+func (c *Context) Trace(name string, t ml.Tensor) {
+	if !c.debug {
+		return
+	}
+
+	shape := t.Shape()
+	shapeArr := make([]int64, 4)
+	for i := 0; i < len(shape); i++ {
+		shapeArr[i] = shape[i]
+	}
+
+	c.traceGraph.Graph = append(
+		c.traceGraph.Graph,
+		ml.GraphLayer{
+			Name:  name,
+			Shape: shapeArr,
+		},
+	)
+}
+
+func (c *Context) GetTrace() ml.Graph {
+	return c.traceGraph
+}
+
 type Tensor struct {
 	t    *C.struct_ggml_tensor
 	data []byte
@@ -555,16 +587,19 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim
 
 	return &Tensor{
 		t: C.ggml_rope_ext(
-			ctx.(*Context).ctx, t.t, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
-			C.int(ropeDim),
-			131072,       // YaRN n_ctx_train
-			ropeTypeNorm, // ROPE_TYPE_NORM
-			C.float(ropeBase),
-			C.float(ropeScale),
-			0.,  // YaRN ext_factor
-			1.,  // YaRN attn_factor
-			32., // YaRN beta_fast
-			1.,  // YaRN beta_slow
+			ctx.(*Context).ctx,
+			t.t,                     // a tensor
+			positionIDs.(*Tensor).t, // b tensor with dims [512, 1, 1, 1]
+			nil,                     // c tensor (not shown in log)
+			C.int(64),               // n_dims: 64
+			2,                       // mode: 2 (ropeTypeNeox = 2)
+			C.int(32768),            // n_ctx_orig: 32768
+			C.float(1000000.0),      // freq_base: 1000000.000000
+			C.float(1.0),            // freq_scale: 1.000000
+			C.float(0.0),            // ext_factor: 0.000000
+			C.float(1.0),            // attn_factor: 1.000000
+			C.float(32.0),           // beta_fast: 32.000000
+			C.float(1.0),            // beta_slow: 1.000000
 		),
 	}
 }
169  model/README.md (Normal file)

@@ -0,0 +1,169 @@
# Ollama Models

!! This is a work in progress document !!

## Architecture

```mermaid
graph TB
    subgraph Models["Model Layer: LLM Implementations"]
        direction TB
        llama["llama/model.go"]
        mllama["mllama/model.go"]
        qwen["qwen2/model.go"]
        qwen_vl["qwen2vl/model.go"]
        pixtral["pixtral/"]

        note1["Each model implements a specific architecture
        - Defines model parameters
        - Handles tokenization
        - Implements forward pass
        - Manages model weights"]
    end

    subgraph ML_Ops["Neural Network Operations"]
        direction TB
        nn_ops["nn/
        linear.go - Matrix operations
        embedding.go - Token embeddings
        normalization.go - Layer normalization
        convolution.go - Conv operations"]

        backend["ml/backend.go
        Hardware Abstraction Layer
        - Defines tensor operations
        - Manages computation graphs
        - Handles memory allocation"]

        note2["Common neural net operations
        used across different models
        - Abstracts hardware details
        - Provides unified API
        - Manages computation flow"]
    end

    subgraph GGML["Hardware Execution Layer"]
        direction TB
        ggml["ggml.go
        CGO Interface
        - Bridges Go and C++
        - Handles type conversion
        - Manages memory between languages"]

        subgraph Hardware_Specific["Hardware-Specific Implementations"]
            direction LR
            cpu["ggml-cpu.h
            CPU optimized ops"]
            cuda["ggml-cuda.h
            NVIDIA GPU ops"]
            metal["ggml-metal.h
            Apple GPU ops"]
            vulkan["ggml-vulkan.h
            Cross-platform GPU"]
            opencl["ggml-opencl.h
            OpenCL acceleration"]
        end

        note3["GGML provides optimized
        implementations for each hardware:
        - Automatic dispatch
        - Hardware-specific optimizations
        - Memory management
        - Parallel execution"]
    end

    %% Connections with explanations
    Models --> |"Makes high-level calls
    (e.g., self-attention)"| ML_Ops
    ML_Ops --> |"Translates to tensor operations
    (e.g., matmul, softmax)"| GGML
    GGML --> |"Executes optimized code
    on target hardware"| Hardware_Specific

    %% Styling
    classDef model fill:#fff,stroke:#01579b,stroke-width:2px
    classDef ml fill:#fff,stroke:#e65100,stroke-width:2px
    classDef hw fill:#fff,stroke:#b71c1c,stroke-width:2px
    classDef note fill:#fff,stroke:#666,stroke-dasharray: 5 5

    class llama,mllama,qwen,qwen_vl,pixtral model
    class nn_ops,backend ml
    class ggml,cpu,cuda,metal,vulkan,opencl hw
    class note1,note2,note3 note

    %% Style subgraphs
    style Models fill:#fff,stroke:#01579b,stroke-width:2px
    style ML_Ops fill:#fff,stroke:#e65100,stroke-width:2px
    style GGML fill:#fff,stroke:#b71c1c,stroke-width:2px
    style Hardware_Specific fill:#fff,stroke:#b71c1c,stroke-width:1px
```
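As a concrete example of this layering, qwen2's attention block (added later in this commit) makes a high-level `nn.Linear` call, which returns `ml.Tensor` values; each tensor operation is then lowered by the ggml backend to a `C.ggml_*` call on the target hardware. Excerpt from `model/qwen2/model.go`:

```go
q := sa.Query.Forward(ctx, hiddenState)               // Model layer -> nn op
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) // ml.Tensor op
q = q.RoPE(ctx, inputPositions, nil, opts.ropeDim, opts.ropeBase, opts.ropeScale)
```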

## Adding support for a new model to Ollama

1. Clone the Ollama repo and get it running locally: https://github.com/ollama/ollama/blob/main/docs/development.md
2. Get the original model (research code) running locally. This will be a Python repository 99.99% of the time.
3. Get a dump of the graph built with PyTorch or safetensors. Use this snippet to do so:

```python
import torch
import sys
from safetensors.torch import load_file


def extract_graph(model_path):
    if model_path.endswith('.safetensors'):
        state_dict = load_file(model_path)
    else:
        state_dict = torch.load(model_path, weights_only=True)

    graph = []
    for name, tensor in state_dict.items():
        if isinstance(tensor, torch.Tensor):
            graph.append({
                "name": name,
                "shape": list(tensor.shape)
            })

    print("{")
    print('  "graph": [')
    for i, layer in enumerate(graph):
        comma = "," if i < len(graph) - 1 else ""
        print(f'    {{"name": "{layer["name"]}", "shape": {layer["shape"]}}}{comma}')
    print("  ]")
    print("}")


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python extract.py <path/to/model>")
        sys.exit(1)

    extract_graph(sys.argv[1])
```
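Assuming the snippet is saved as `extract.py` (the name used in its usage string), a fixture for the test harness below can be produced with something like `python extract.py /path/to/model.safetensors > qwen2.json`.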
4. Look at a previous model implementation pull request and copy the structure of the files needed (see the layout sketch after this list). We will need:
   1. A `model/<model-name>` directory.
   2. A `model/<model-name>/model.go` file to implement the architecture and forward pass.
   3. A `model/<model-name>/convert.go` file to implement the conversion from PyTorch/safetensors to GGML.
   4. `model/<model-name>/model_test.go` and `model/<model-name>/convert_test.go` files for testing.
   5. Modifications to the main paths to make this new model accessible.
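For orientation, the resulting layout for qwen2 (the example used in this commit; `convert.go` and the test files are per the list above) would look like:

```
model/
└── qwen2/
    ├── model.go         # architecture and forward pass
    ├── convert.go       # PyTorch/safetensors -> GGML conversion
    ├── model_test.go
    └── convert_test.go
```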
5. Open a draft pull request in the `ollama/ollama` repo, as a place to ask questions and get answers from Ollama maintainers.
6. Implement conversion from the model weights (PyTorch, safetensors) to GGML in the `model/<your-model>/convert.go` file. Reference other `convert.go` files.
7. Create a Modelfile that only references the PyTorch/safetensors directory; we will handle the other fields later.

Modelfile:
```
FROM /path/to/model
```

Use `ollama create` to convert the model:

`go run . create <my-model> -f /path/to/Modelfile`

8. Implement the `New()` and `Forward()` logic in `model/<your-model>/model.go`. Reference other `model.go` files.

Run the model and get the debug output of the forward pass to compare with the output of the research implementation from step 2:

`OLLAMA_DEBUG=1 go run . run <my-model>`

9. Implement a new tokenizer, if needed.
10. Test text generation. This step requires knowing the prompt format:

`go run . run <my-model> "hello"`

11. Add tests to `model/<your-model>/model_test.go` and `model/<your-model>/convert_test.go`.
12. Push changes to the `ollama/ollama` pull request, and move the pull request out of the draft state.
13. Push the model to ollama.com:
    1. Find the model's prompt format and convert it to a Go template.
    2. Create a Modelfile `FROM` the converted GGUF; add the `TEMPLATE`, `LICENSE`, and parameters if needed.
    3. `ollama create <your-namespace>/<your-model> -f /path/to/Modelfile`
    4. `ollama push <your-namespace>/<your-model>`
14. Run end-to-end integration tests.
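As a sketch for step 13: qwen2 models use a ChatML-style prompt format, so a Modelfile might look like the following (illustrative only; the template and stop token must match the model's actual format):

```
FROM /path/to/converted.gguf
TEMPLATE """<|im_start|>user
{{ .Prompt }}<|im_end|>
<|im_start|>assistant
"""
PARAMETER stop <|im_end|>
```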
91  model/model_test/model_test.go (Normal file)

@@ -0,0 +1,91 @@
package modeltest

import (
	"encoding/json"
	"os"
	"path/filepath"
	"reflect"
	"testing"

	"github.com/ollama/ollama/cache"
	"github.com/ollama/ollama/convert"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model"
	_ "github.com/ollama/ollama/model/qwen2"
)

func TestForward(t *testing.T) {
	cases := []string{
		"qwen2",
		// Add more model architectures here...
	}

	for _, tt := range cases {
		t.Run(tt, func(t *testing.T) {
			t.Parallel()

			p := filepath.Join("testdata", tt)
			if testing.Short() {
				t.Skip("skipping in short mode")
			} else if _, err := os.Stat(p); err != nil {
				t.Skipf("%s not found", p)
			}

			f, err := os.CreateTemp(t.TempDir(), "f16")
			if err != nil {
				t.Fatal(err)
			}
			defer func() {
				f.Close()
				os.Remove(f.Name())
			}()

			if err := convert.ConvertModel(os.DirFS(p), f); err != nil {
				t.Fatal(err)
			}

			m, err := model.New(f.Name())
			if err != nil {
				t.Fatal(err)
			}
			b := m.Backend()
			ctx := b.NewContext()
			ctx.SetDebug(true)

			// Run forward pass
			_, err = model.Forward(ctx, m, model.WithCache(cache.NewCausalCache(m.Backend(), 2048, ml.DTypeF32)))
			if err != nil {
				t.Fatal(err)
			}

			// Validate the graph layers
			data, err := os.ReadFile(filepath.Join("testdata", tt+".json"))
			if err != nil {
				t.Fatal(err)
			}
			var expected ml.Graph
			if err := json.Unmarshal(data, &expected); err != nil {
				t.Fatal(err)
			}

			result := ctx.GetTrace()

			if len(result.Graph) != len(expected.Graph) {
				t.Errorf("expected %d layers, got %d", len(expected.Graph), len(result.Graph))
			}

			for i, layer := range expected.Graph {
				if i >= len(result.Graph) {
					break
				}
				actual := result.Graph[i]
				if layer.Name != actual.Name {
					t.Errorf("layer %d: expected name %s, got %s", i, layer.Name, actual.Name)
				}
				if !reflect.DeepEqual(layer.Shape, actual.Shape) {
					t.Errorf("layer %d: expected shape %v, got %v", i, layer.Shape, actual.Shape)
				}
			}
		})
	}
}
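The expected `testdata/<arch>.json` fixture below matches the output format of the graph-dump snippet in `model/README.md`, so it can presumably be regenerated from the original research checkpoint with that script.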
294  model/model_test/testdata/qwen2.json (vendored) (Normal file)

@@ -0,0 +1,294 @@
{
  "graph": [
    {"name": "model.embed_tokens.weight", "shape": [151936, 896]},
    {"name": "model.layers.0.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.0.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.0.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.0.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.0.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.0.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.0.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.0.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.0.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.0.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.0.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.0.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.1.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.1.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.1.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.1.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.1.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.1.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.1.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.1.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.1.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.1.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.1.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.1.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.10.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.10.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.10.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.10.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.10.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.10.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.10.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.10.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.10.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.10.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.10.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.10.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.11.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.11.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.11.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.11.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.11.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.11.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.11.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.11.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.11.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.11.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.11.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.11.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.12.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.12.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.12.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.12.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.12.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.12.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.12.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.12.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.12.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.12.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.12.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.12.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.13.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.13.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.13.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.13.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.13.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.13.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.13.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.13.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.13.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.13.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.13.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.13.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.14.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.14.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.14.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.14.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.14.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.14.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.14.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.14.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.14.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.14.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.14.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.14.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.15.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.15.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.15.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.15.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.15.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.15.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.15.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.15.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.15.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.15.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.15.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.15.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.16.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.16.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.16.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.16.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.16.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.16.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.16.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.16.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.16.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.16.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.16.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.16.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.17.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.17.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.17.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.17.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.17.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.17.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.17.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.17.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.17.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.17.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.17.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.17.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.18.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.18.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.18.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.18.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.18.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.18.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.18.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.18.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.18.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.18.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.18.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.18.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.19.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.19.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.19.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.19.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.19.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.19.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.19.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.19.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.19.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.19.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.19.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.19.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.2.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.2.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.2.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.2.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.2.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.2.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.2.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.2.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.2.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.2.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.2.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.2.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.20.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.20.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.20.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.20.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.20.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.20.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.20.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.20.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.20.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.20.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.20.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.20.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.21.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.21.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.21.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.21.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.21.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.21.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.21.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.21.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.21.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.21.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.21.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.21.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.22.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.22.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.22.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.22.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.22.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.22.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.22.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.22.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.22.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.22.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.22.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.22.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.23.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.23.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.23.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.23.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.23.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.23.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.23.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.23.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.23.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.23.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.23.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.23.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.3.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.3.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.3.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.3.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.3.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.3.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.3.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.3.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.3.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.3.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.3.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.3.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.4.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.4.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.4.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.4.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.4.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.4.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.4.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.4.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.4.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.4.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.4.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.4.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.5.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.5.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.5.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.5.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.5.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.5.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.5.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.5.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.5.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.5.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.5.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.5.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.6.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.6.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.6.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.6.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.6.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.6.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.6.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.6.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.6.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.6.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.6.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.6.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.7.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.7.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.7.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.7.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.7.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.7.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.7.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.7.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.7.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.7.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.7.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.7.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.8.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.8.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.8.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.8.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.8.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.8.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.8.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.8.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.8.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.8.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.8.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.8.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.9.input_layernorm.weight", "shape": [896]},
    {"name": "model.layers.9.mlp.down_proj.weight", "shape": [896, 4864]},
    {"name": "model.layers.9.mlp.gate_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.9.mlp.up_proj.weight", "shape": [4864, 896]},
    {"name": "model.layers.9.post_attention_layernorm.weight", "shape": [896]},
    {"name": "model.layers.9.self_attn.k_proj.bias", "shape": [128]},
    {"name": "model.layers.9.self_attn.k_proj.weight", "shape": [128, 896]},
    {"name": "model.layers.9.self_attn.o_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.9.self_attn.q_proj.bias", "shape": [896]},
    {"name": "model.layers.9.self_attn.q_proj.weight", "shape": [896, 896]},
    {"name": "model.layers.9.self_attn.v_proj.bias", "shape": [128]},
    {"name": "model.layers.9.self_attn.v_proj.weight", "shape": [128, 896]},
    {"name": "model.norm.weight", "shape": [896]}
  ]
}
201  model/qwen2/model.go (Normal file)

@@ -0,0 +1,201 @@
package qwen2

import (
	"fmt"
	"log/slog"
	"math"

	"github.com/ollama/ollama/cache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
)

type Options struct {
	hiddenSize, numHeads, numKVHeads int64
	eps, ropeBase, ropeScale         float32
	ropeDim                          uint32
}

type Model struct {
	model.Base
	model.BytePairEncoding

	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []Layer       `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	*Options
}

func New(c ml.Config) (model.Model, error) {
	m := &Model{
		BytePairEncoding: model.BytePairEncoding{
			Pretokenizer: c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
			Vocabulary: &model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Types:  c.Uints("tokenizer.ggml.token_type"),
				Merges: c.Strings("tokenizer.ggml.merges"),
				BOS:    c.Uint("tokenizer.ggml.bos_token_id"),
				EOS:    c.Uint("tokenizer.ggml.eos_token_id"),
			},
		},
		Layers: make([]Layer, c.Uint("block_count")),
		Options: &Options{
			hiddenSize: int64(c.Uint("embedding_length")),
			numHeads:   int64(c.Uint("attention.head_count")),
			numKVHeads: int64(c.Uint("attention.head_count_kv")),
			eps:        c.Float("attention.layer_norm_rms_epsilon"),
			ropeBase:   c.Float("rope.freq_base"),
			ropeScale:  c.Float("rope.freq_scale", 1),
			ropeDim:    c.Uint("rope.dimension_count", 64),
		},
	}

	slog.Debug("model configuration",
		"arch", "qwen2",
		"vocab_size", len(c.Strings("tokenizer.ggml.tokens")),
		"n_merges", len(c.Strings("tokenizer.ggml.merges")),
		"n_ctx_train", c.Uint("context_length"),
		"n_embd", m.hiddenSize,
		"n_layer", len(m.Layers),
		"n_head", m.numHeads,
		"n_head_kv", m.numKVHeads,
		"n_rot", m.ropeDim,
		"f_norm_rms_eps", m.eps,
		"rope_freq_base", m.ropeBase,
		"rope_freq_scale", m.ropeScale,
		"bos_token_id", c.Uint("tokenizer.ggml.bos_token_id"),
		"eos_token_id", c.Uint("tokenizer.ggml.eos_token_id"),
	)

	return m, nil
}

type SelfAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}

func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, inputPositions ml.Tensor, layerIdx int, cache cache.Cache, opts *Options) ml.Tensor {
	batchSize := hiddenState.Dim(1)
	headDim := opts.hiddenSize / opts.numHeads

	q := sa.Query.Forward(ctx, hiddenState)
	ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.q_proj", layerIdx), q)

	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
	q = q.RoPE(ctx, inputPositions, nil, opts.ropeDim, opts.ropeBase, opts.ropeScale)
	ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.q_proj.rope", layerIdx), q)

	k := sa.Key.Forward(ctx, hiddenState)
	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
	k = k.RoPE(ctx, inputPositions, nil, opts.ropeDim, opts.ropeBase, opts.ropeScale)
	ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.k_proj.rope", layerIdx), k)

	v := sa.Value.Forward(ctx, hiddenState)
	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
	ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.v_proj", layerIdx), v)

	k, v, mask := cache.Put(ctx, k, v)

	q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

	kq := k.Mulmat(ctx, q)
	kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
	kq = kq.Add(ctx, mask)
	kq = kq.Softmax(ctx)

	kqv := v.Mulmat(ctx, kq)
	kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)

	output := sa.Output.Forward(ctx, kqv)
	return output
}

type MLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, hiddenState)
}

type Layer struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention *SelfAttention
	MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
	MLP           *MLP
}

func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, layerIdx int, cache cache.Cache, opts *Options) ml.Tensor {
	ctx.Trace(fmt.Sprintf("model.layers.%d.input", layerIdx), hiddenState)
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	ctx.Trace(fmt.Sprintf("model.layers.%d.input_layernorm", layerIdx), hiddenState)

	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, layerIdx, cache, opts)
	ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.output", layerIdx), hiddenState)

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState
	ctx.Trace(fmt.Sprintf("model.layers.%d.self_attn.residual", layerIdx), hiddenState)

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	ctx.Trace(fmt.Sprintf("model.layers.%d.post_attention_layernorm", layerIdx), hiddenState)

	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	ctx.Trace(fmt.Sprintf("model.layers.%d.mlp", layerIdx), hiddenState)

	output := hiddenState.Add(ctx, residual)
	ctx.Trace(fmt.Sprintf("model.layers.%d.output", layerIdx), output)

	return output
}

func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
	slog.Debug("input tokens", "input_ids", opts.Inputs())
	inputs, err := ctx.FromIntSlice(opts.Inputs(), len(opts.Inputs()))
	if err != nil {
		return nil, err
	}

	positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions()))
	if err != nil {
		return nil, err
	}

	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
	ctx.Trace("model.embed_tokens", hiddenState)

	for i, layer := range m.Layers {
		hiddenState = layer.Forward(ctx, hiddenState, positions, i, opts.Cache.Sub(i), m.Options)
	}

	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	ctx.Trace("model.norm", hiddenState)

	hiddenState = m.Output.Forward(ctx, hiddenState)
	ctx.Trace("model.output", hiddenState)

	outputs, err := ctx.FromIntSlice(opts.Outputs(), len(opts.Outputs()))
	if err != nil {
		return nil, err
	}

	return hiddenState.Rows(ctx, outputs), nil
}

func init() {
	model.Register("qwen2", New)
}
@@ -32,6 +32,7 @@ import (
 
 	_ "github.com/ollama/ollama/model/llama"
 	_ "github.com/ollama/ollama/model/mllama"
+	_ "github.com/ollama/ollama/model/qwen2"
 )
 
 // input is an element of the prompt to process, either