mirror of https://github.com/ollama/ollama.git
synced 2025-03-18 05:41:43 +01:00
next ollama runner (#7913)
feat: add new Ollama engine using ggml through cgo

This change introduces a new way to run pretrained models. It introduces 3 high-level interfaces and a number of smaller helper interfaces to facilitate this.

- `model.Model` defines the interface for a model architecture. Models such as `llama` and `mllama`, which are provided as examples, can implement the model's forward propagation in the `Forward` method. This method will be called to generate completions. This interface can be found in `model/model.go`.
- `ml.Backend` defines the interface for a backend tensor library, in this case `ggml`. Among other things, a Backend is responsible for loading a pretrained model into hardware (GPU, CPU, etc.) and providing an interface for Models to access loaded tensors. This interface can be found in `ml/backend.go`.
- `ml.Tensor` defines the interface for a tensor and tensor operations.

This is the first implementation of the new engine. Follow-up PRs will implement more features:

- non-greedy sampling (#8410)
- integration with Ollama and KV caching (#8301)
- more model support (#9080), with more coming soon

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
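To make the shape of these three interfaces concrete, here is a minimal sketch of how they are meant to compose. The stand-in types and signatures below are assumptions for illustration only; the actual definitions live in `model/model.go` and `ml/backend.go` and may differ.

```go
// Illustrative stand-ins for the interfaces described above; not the
// real definitions from model/model.go or ml/backend.go.
package sketch

// Tensor stands in for ml.Tensor: an opaque handle plus tensor ops.
type Tensor interface {
	Dim(i int) int64
}

// Backend stands in for ml.Backend: it loads a pretrained model into
// hardware (GPU, CPU, etc.) and hands out the loaded tensors.
type Backend interface {
	Get(name string) Tensor
}

// Model stands in for model.Model: one implementation per architecture,
// with forward propagation in Forward.
type Model interface {
	Forward(inputs []int32) (Tensor, error)
}

// A llama-style architecture would satisfy Model by looking up its
// weights through the Backend and composing tensor operations.
type llama struct {
	b Backend
}

func (m *llama) Forward(inputs []int32) (Tensor, error) {
	tokenEmbd := m.b.Get("token_embd.weight")
	// ... embed inputs, run attention blocks, project to logits ...
	_ = tokenEmbd
	return nil, nil
}
```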
This commit is contained in:
parent
8cf16063a5
commit
58245413f4
cache
convert
convert.go convert_bert.go convert_commandr.go convert_gemma.go convert_gemma2.go convert_gemma2_adapter.go convert_llama.go convert_llama_adapter.go convert_mixtral.go convert_phi3.go convert_qwen2.go convert_test.go
fs
llm
ml
model
parser
sample
server
create.go images.go model.go routes.go routes_create_test.go routes_generate_test.go routes_test.go sched.go sched_test.go
template
63 cache/cache.go vendored Normal file
@ -0,0 +1,63 @@
package cache

import (
    "github.com/ollama/ollama/ml"
)

type Options struct {
    Position int
}

type Cache interface {
    Sub(i int) Cache
    Put(ctx ml.Context, key, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor)
}

type Simple struct {
    DType    ml.DType
    Capacity int

    keys, values []ml.Tensor
}

func (c *Simple) Sub(i int) Cache {
    if i >= len(c.keys) {
        c.keys = append(c.keys, make([]ml.Tensor, i-len(c.keys)+1)...)
        c.values = append(c.values, make([]ml.Tensor, i-len(c.values)+1)...)
    }

    return &Simple{
        keys:     c.keys[i : i+1],
        values:   c.values[i : i+1],
        Capacity: c.Capacity,
        DType:    c.DType,
    }
}

func (c *Simple) Put(ctx ml.Context, key, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor) {
    if c.keys[0] == nil || c.values[0] == nil {
        c.keys[0] = ctx.Zeros(c.DType, int(key.Dim(0)*key.Dim(1))*c.Capacity)
        c.values[0] = ctx.Zeros(c.DType, int(value.Dim(0)*value.Dim(1))*c.Capacity)
    }

    ctx.Forward(key.Copy(ctx, c.keys[0].View(ctx, int(key.Stride(2))*opts.Position, int(key.Dim(0)*key.Dim(1)*key.Dim(2)))))
    ctx.Forward(value.Copy(ctx, c.values[0].View(ctx, int(value.Stride(2))*opts.Position, int(value.Dim(0)*value.Dim(1)*value.Dim(2)))))

    n := min(c.Capacity, int(key.Dim(2))+opts.Position)

    key = c.keys[0].View(ctx, 0,
        int(key.Dim(0)), int(key.Stride(1)),
        int(key.Dim(1)), int(key.Stride(2)),
        n,
    )

    value = c.values[0].View(ctx, 0,
        int(value.Dim(0)), int(value.Stride(1)),
        int(value.Dim(1)), int(value.Stride(2)),
        n,
    )

    // TODO shift context if necessary

    return key, value
}
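How this cache is meant to be used can be read off the code: a model builds one `Simple` cache sized to its context, each attention layer scopes it with `Sub(i)`, and `Put` appends the current step's keys/values at `opts.Position` while returning views over everything cached so far. A hypothetical call site, with all names assumed for illustration:

```go
package model

import (
	"github.com/ollama/ollama/cache"
	"github.com/ollama/ollama/ml"
)

// Hypothetical call site: key and value are the current step's
// projections, pos is the position of the first new token, and i is
// the attention layer index. Names here are illustrative only.
func attendWithCache(ctx ml.Context, c cache.Cache, i, pos int, key, value ml.Tensor) (ml.Tensor, ml.Tensor) {
	// Sub(i) scopes the cache to layer i; Put appends this step at pos
	// and returns views covering every cached position so far.
	return c.Sub(i).Put(ctx, key, value, cache.Options{Position: pos})
}
```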
convert/convert.go
@ -9,7 +9,7 @@ import (
 	"log/slog"
 	"strings"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type ModelParameters struct {
@ -27,8 +27,8 @@ type AdapterParameters struct {
 	} `json:"lora_parameters"`
 }

-func (ModelParameters) KV(t *Tokenizer) llm.KV {
-	kv := llm.KV{
+func (ModelParameters) KV(t *Tokenizer) ggml.KV {
+	kv := ggml.KV{
 		"general.file_type":            uint32(1),
 		"general.quantization_version": uint32(2),
 		"tokenizer.ggml.pre":           t.Pre,
@ -54,7 +54,7 @@ func (ModelParameters) KV(t *Tokenizer) llm.KV {
 	return kv
 }

-func (p AdapterParameters) KV() llm.KV {
+func (p AdapterParameters) KV() ggml.KV {
 	var alpha float32
 	if p.LoraParameters.Alpha == 0 {
 		alpha = float32(p.Alpha)
@ -62,7 +62,7 @@ func (p AdapterParameters) KV() llm.KV {
 		alpha = p.LoraParameters.Alpha
 	}

-	kv := llm.KV{
+	kv := ggml.KV{
 		"adapter.lora.alpha": alpha,
 		"adapter.type":       "lora",
 		"general.file_type":  uint32(1),
@ -79,19 +79,19 @@ func (ModelParameters) specialTokenTypes() []string {
 	}
 }

-func (ModelParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error {
-	return llm.WriteGGUF(ws, kv, ts)
+func (ModelParameters) writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
+	return ggml.WriteGGUF(ws, kv, ts)
 }

-func (AdapterParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error {
-	return llm.WriteGGUF(ws, kv, ts)
+func (AdapterParameters) writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
+	return ggml.WriteGGUF(ws, kv, ts)
 }

 type ModelConverter interface {
 	// KV maps parameters to LLM key-values
-	KV(*Tokenizer) llm.KV
+	KV(*Tokenizer) ggml.KV
 	// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
-	Tensors([]Tensor) []llm.Tensor
+	Tensors([]Tensor) []ggml.Tensor
 	// Replacements returns a list of string pairs to replace in tensor names.
 	// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
 	Replacements() []string
@ -99,7 +99,7 @@ type ModelConverter interface {
 	// specialTokenTypes returns any special token types the model uses
 	specialTokenTypes() []string
 	// writeFile writes the model to the provided io.WriteSeeker
-	writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
+	writeFile(io.WriteSeeker, ggml.KV, []ggml.Tensor) error
 }

 type moreParser interface {
@ -108,17 +108,17 @@ type moreParser interface {

 type AdapterConverter interface {
 	// KV maps parameters to LLM key-values
-	KV(llm.KV) llm.KV
+	KV(ggml.KV) ggml.KV
 	// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
-	Tensors([]Tensor) []llm.Tensor
+	Tensors([]Tensor) []ggml.Tensor
 	// Replacements returns a list of string pairs to replace in tensor names.
 	// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
 	Replacements() []string

-	writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
+	writeFile(io.WriteSeeker, ggml.KV, []ggml.Tensor) error
 }

-func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV llm.KV) error {
+func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV ggml.KV) error {
 	bts, err := fs.ReadFile(fsys, "adapter_config.json")
 	if err != nil {
 		return err
convert/convert_bert.go
@ -8,7 +8,7 @@ import (
 	"slices"
 	"strings"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type bertModel struct {
@ -85,7 +85,7 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
 	return nil
 }

-func (p *bertModel) KV(t *Tokenizer) llm.KV {
+func (p *bertModel) KV(t *Tokenizer) ggml.KV {
 	kv := p.ModelParameters.KV(t)
 	kv["general.architecture"] = "bert"
 	kv["bert.attention.causal"] = false
@ -132,8 +132,8 @@ func (p *bertModel) KV(t *Tokenizer) llm.KV {
 	return kv
 }

-func (p *bertModel) Tensors(ts []Tensor) []llm.Tensor {
-	var out []llm.Tensor
+func (p *bertModel) Tensors(ts []Tensor) []ggml.Tensor {
+	var out []ggml.Tensor
 	for _, t := range ts {
 		if slices.Contains([]string{
 			"embeddings.position_ids",
@ -143,7 +143,7 @@ func (p *bertModel) Tensors(ts []Tensor) []llm.Tensor {
 			continue
 		}

-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: t.Shape(),
convert/convert_commandr.go
@ -3,7 +3,7 @@ package convert
 import (
 	"cmp"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type commandrModel struct {
@ -24,7 +24,7 @@ type commandrModel struct {

 var _ ModelConverter = (*commandrModel)(nil)

-func (p *commandrModel) KV(t *Tokenizer) llm.KV {
+func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
 	kv := p.ModelParameters.KV(t)
 	kv["general.architecture"] = "command-r"
 	kv["general.name"] = "command-r"
@ -43,10 +43,10 @@ func (p *commandrModel) KV(t *Tokenizer) llm.KV {
 	return kv
 }

-func (p *commandrModel) Tensors(ts []Tensor) []llm.Tensor {
-	var out []llm.Tensor
+func (p *commandrModel) Tensors(ts []Tensor) []ggml.Tensor {
+	var out []ggml.Tensor
 	for _, t := range ts {
-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: t.Shape(),
convert/convert_gemma.go
@ -6,7 +6,7 @@ import (
 	"github.com/pdevine/tensor"
 	"github.com/pdevine/tensor/native"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type gemmaModel struct {
@ -23,7 +23,7 @@ type gemmaModel struct {

 var _ ModelConverter = (*gemmaModel)(nil)

-func (p *gemmaModel) KV(t *Tokenizer) llm.KV {
+func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
 	kv := p.ModelParameters.KV(t)
 	kv["general.architecture"] = "gemma"
 	kv["gemma.context_length"] = p.MaxPositionEmbeddings
@ -42,14 +42,14 @@ func (p *gemmaModel) KV(t *Tokenizer) llm.KV {
 	return kv
 }

-func (p *gemmaModel) Tensors(ts []Tensor) []llm.Tensor {
-	var out []llm.Tensor
+func (p *gemmaModel) Tensors(ts []Tensor) []ggml.Tensor {
+	var out []ggml.Tensor
 	for _, t := range ts {
 		if strings.HasSuffix(t.Name(), "_norm.weight") {
 			t.SetRepacker(p.addOne)
 		}

-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: t.Shape(),
convert/convert_gemma2.go
@ -1,8 +1,6 @@
 package convert

-import (
-	"github.com/ollama/ollama/llm"
-)
+import "github.com/ollama/ollama/fs/ggml"

 type gemma2Model struct {
 	gemmaModel
@ -11,7 +9,7 @@ type gemma2Model struct {
 	FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
 }

-func (p *gemma2Model) KV(t *Tokenizer) llm.KV {
+func (p *gemma2Model) KV(t *Tokenizer) ggml.KV {
 	kv := p.ModelParameters.KV(t)
 	kv["general.architecture"] = "gemma2"
 	kv["gemma2.context_length"] = p.MaxPositionEmbeddings
convert/convert_gemma2_adapter.go
@ -6,7 +6,7 @@ import (
 	"github.com/pdevine/tensor"
 	"github.com/pdevine/tensor/native"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type gemma2Adapter struct {
@ -15,14 +15,14 @@ type gemma2Adapter struct {

 var _ AdapterConverter = (*gemma2Adapter)(nil)

-func (p *gemma2Adapter) KV(baseKV llm.KV) llm.KV {
+func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
 	kv := p.AdapterParameters.KV()
 	kv["general.architecture"] = "gemma2"
 	return kv
 }

-func (p *gemma2Adapter) Tensors(ts []Tensor) []llm.Tensor {
-	var out []llm.Tensor
+func (p *gemma2Adapter) Tensors(ts []Tensor) []ggml.Tensor {
+	var out []ggml.Tensor
 	for _, t := range ts {
 		shape := t.Shape()
 		if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
@ -31,7 +31,7 @@ func (p *gemma2Adapter) Tensors(ts []Tensor) []llm.Tensor {
 			t.SetRepacker(p.repack)
 		}

-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: t.Shape(),
convert/convert_llama.go
@ -9,7 +9,7 @@ import (
 	"github.com/pdevine/tensor"
 	"github.com/pdevine/tensor/native"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type llamaModel struct {
@ -46,7 +46,7 @@ type llamaModel struct {

 var _ ModelConverter = (*llamaModel)(nil)

-func (p *llamaModel) KV(t *Tokenizer) llm.KV {
+func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
 	kv := p.ModelParameters.KV(t)
 	kv["general.architecture"] = "llama"
 	kv["llama.vocab_size"] = p.VocabSize
@ -120,11 +120,11 @@ func (p *llamaModel) KV(t *Tokenizer) llm.KV {
 	return kv
 }

-func (p *llamaModel) Tensors(ts []Tensor) []llm.Tensor {
-	var out []llm.Tensor
+func (p *llamaModel) Tensors(ts []Tensor) []ggml.Tensor {
+	var out []ggml.Tensor

 	if p.RopeScaling.factors != nil {
-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  "rope_freqs.weight",
 			Kind:  0,
 			Shape: []uint64{uint64(len(p.RopeScaling.factors))},
@ -138,7 +138,7 @@ func (p *llamaModel) Tensors(ts []Tensor) []llm.Tensor {
 			t.SetRepacker(p.repack)
 		}

-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: t.Shape(),
convert/convert_llama_adapter.go
@ -7,7 +7,7 @@ import (
 	"github.com/pdevine/tensor"
 	"github.com/pdevine/tensor/native"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type llamaAdapter struct {
@ -18,7 +18,7 @@ type llamaAdapter struct {

 var _ AdapterConverter = (*llamaAdapter)(nil)

-func (p *llamaAdapter) KV(baseKV llm.KV) llm.KV {
+func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
 	kv := p.AdapterParameters.KV()
 	kv["general.architecture"] = "llama"
 	kv["llama.attention.head_count"] = baseKV["llama.attention.head_count"]
@ -29,8 +29,8 @@ func (p *llamaAdapter) KV(baseKV llm.KV) llm.KV {
 	return kv
 }

-func (p *llamaAdapter) Tensors(ts []Tensor) []llm.Tensor {
-	var out []llm.Tensor
+func (p *llamaAdapter) Tensors(ts []Tensor) []ggml.Tensor {
+	var out []ggml.Tensor
 	for _, t := range ts {
 		shape := t.Shape()
 		if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
@ -41,7 +41,7 @@ func (p *llamaAdapter) Tensors(ts []Tensor) []llm.Tensor {
 			t.SetRepacker(p.repack)
 		}

-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: shape,
convert/convert_mixtral.go
@ -6,7 +6,7 @@ import (
 	"slices"
 	"strings"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type mixtralModel struct {
@ -15,7 +15,7 @@ type mixtralModel struct {
 	NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
 }

-func (p *mixtralModel) KV(t *Tokenizer) llm.KV {
+func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
 	kv := p.llamaModel.KV(t)

 	if p.NumLocalExperts > 0 {
@ -29,7 +29,7 @@ func (p *mixtralModel) KV(t *Tokenizer) llm.KV {
 	return kv
 }

-func (p *mixtralModel) Tensors(ts []Tensor) []llm.Tensor {
+func (p *mixtralModel) Tensors(ts []Tensor) []ggml.Tensor {
 	oldnew := []string{
 		"model.layers", "blk",
 		"w1", "ffn_gate_exps",
@ -56,10 +56,10 @@ func (p *mixtralModel) Tensors(ts []Tensor) []llm.Tensor {
 		return true
 	})

-	var out []llm.Tensor
+	var out []ggml.Tensor
 	for n, e := range experts {
 		// TODO(mxyng): sanity check experts
-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  n,
 			Kind:  e[0].Kind(),
 			Shape: append([]uint64{uint64(len(e))}, e[0].Shape()...),
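The shape arithmetic in the expert-merging branch above is just prepending the expert count to the per-expert shape. A worked example, with sizes assumed purely for illustration:

```go
package main

import "fmt"

func main() {
	// Assumed sizes for illustration: 8 experts, each [4096, 14336].
	perExpert := []uint64{4096, 14336}
	numExperts := 8

	stacked := append([]uint64{uint64(numExperts)}, perExpert...)
	fmt.Println(stacked) // [8 4096 14336]: one tensor holding all experts
}
```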
convert/convert_phi3.go
@ -8,7 +8,7 @@ import (
 	"strings"
 	"sync"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type phi3Model struct {
@ -37,7 +37,7 @@ type phi3Model struct {

 var _ ModelConverter = (*phi3Model)(nil)

-func (p *phi3Model) KV(t *Tokenizer) llm.KV {
+func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
 	kv := p.ModelParameters.KV(t)
 	kv["general.architecture"] = "phi3"
 	kv["phi3.context_length"] = p.MaxPositionEmbeddings
@ -68,19 +68,19 @@ func (p *phi3Model) KV(t *Tokenizer) llm.KV {
 	return kv
 }

-func (p *phi3Model) Tensors(ts []Tensor) []llm.Tensor {
+func (p *phi3Model) Tensors(ts []Tensor) []ggml.Tensor {
 	var addRopeFactors sync.Once

-	out := make([]llm.Tensor, 0, len(ts)+2)
+	out := make([]ggml.Tensor, 0, len(ts)+2)
 	for _, t := range ts {
 		if strings.HasPrefix(t.Name(), "blk.0.") {
 			addRopeFactors.Do(func() {
-				out = append(out, llm.Tensor{
+				out = append(out, ggml.Tensor{
 					Name:  "rope_factors_long.weight",
 					Kind:  0,
 					Shape: []uint64{uint64(len(p.RopeScaling.LongFactor))},
 					WriterTo: p.RopeScaling.LongFactor,
-				}, llm.Tensor{
+				}, ggml.Tensor{
 					Name:  "rope_factors_short.weight",
 					Kind:  0,
 					Shape: []uint64{uint64(len(p.RopeScaling.ShortFactor))},
@ -89,7 +89,7 @@ func (p *phi3Model) Tensors(ts []Tensor) []llm.Tensor {
 			})
 		}

-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: t.Shape(),
convert/convert_qwen2.go
@ -1,6 +1,6 @@
 package convert

-import "github.com/ollama/ollama/llm"
+import "github.com/ollama/ollama/fs/ggml"

 type qwen2Model struct {
 	ModelParameters
@ -21,7 +21,7 @@ type qwen2Model struct {

 var _ ModelConverter = (*qwen2Model)(nil)

-func (q *qwen2Model) KV(t *Tokenizer) llm.KV {
+func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
 	kv := q.ModelParameters.KV(t)
 	kv["general.architecture"] = "qwen2"
 	kv["qwen2.block_count"] = q.HiddenLayers
@ -45,10 +45,10 @@ func (q *qwen2Model) KV(t *Tokenizer) llm.KV {
 	return kv
 }

-func (q *qwen2Model) Tensors(ts []Tensor) []llm.Tensor {
-	var out []llm.Tensor
+func (q *qwen2Model) Tensors(ts []Tensor) []ggml.Tensor {
+	var out []ggml.Tensor
 	for _, t := range ts {
-		out = append(out, llm.Tensor{
+		out = append(out, ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: t.Shape(),
convert/convert_test.go
@ -20,7 +20,7 @@ import (

 	"golang.org/x/exp/maps"

-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )

 type tensorData struct {
@ -29,7 +29,7 @@ type tensorData struct {
 	Shape []int `json:"shape"`
 }

-func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, *llm.Tensors) {
+func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
 	t.Helper()

 	f, err := os.CreateTemp(t.TempDir(), "f16")
@ -48,7 +48,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, *llm.Tensors) {
 	}
 	t.Cleanup(func() { r.Close() })

-	m, _, err := llm.DecodeGGML(r, math.MaxInt)
+	m, _, err := ggml.Decode(r, math.MaxInt)
 	if err != nil {
 		t.Fatal(err)
 	}
@ -60,7 +60,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, *llm.Tensors) {
 	return r, m.KV(), m.Tensors()
 }

-func generateResultsJSON(t *testing.T, f *os.File, kv llm.KV, tensors *llm.Tensors) map[string]string {
+func generateResultsJSON(t *testing.T, f *os.File, kv ggml.KV, tensors ggml.Tensors) map[string]string {
 	actual := make(map[string]string)
 	for k, v := range kv {
 		if s, ok := v.(json.Marshaler); !ok {
@ -75,7 +75,7 @@ func generateResultsJSON(t *testing.T, f *os.File, kv llm.KV, tensors *llm.Tensors) map[string]string {
 		}
 	}

-	for _, tensor := range tensors.Items {
+	for _, tensor := range tensors.Items() {
 		sha256sum := sha256.New()
 		sr := io.NewSectionReader(f, int64(tensors.Offset+tensor.Offset), int64(tensor.Size()))
 		if _, err := io.Copy(sha256sum, sr); err != nil {
@ -332,7 +332,7 @@ func TestConvertAdapter(t *testing.T) {
 	}
 	defer r.Close()

-	m, _, err := llm.DecodeGGML(r, math.MaxInt)
+	m, _, err := ggml.Decode(r, math.MaxInt)
 	if err != nil {
 		t.Fatal(err)
 	}
llm/ggml.go → fs/ggml/ggml.go
@ -1,15 +1,15 @@
-package llm
+package ggml

 import (
 	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
 	"log/slog"
 	"slices"
 	"strings"
 	"sync"

-	"github.com/ollama/ollama/util/bufioutil"
+	"github.com/ollama/ollama/fs/util/bufioutil"
 )

 type GGML struct {
@ -19,145 +19,166 @@ type GGML struct {

 type model interface {
 	KV() KV
-	Tensors() *Tensors
+	Tensors() Tensors
 }

 type KV map[string]any

-func (kv KV) u64(key string) uint64 {
-	switch v := kv[key].(type) {
-	case uint64:
-		return v
-	case uint32:
-		return uint64(v)
-	case float64:
-		return uint64(v)
-	default:
-		return 0
-	}
-}
-
 func (kv KV) Architecture() string {
-	if s, ok := kv["general.architecture"].(string); ok {
-		return s
-	}
-
-	return "unknown"
+	return kv.String("general.architecture", "unknown")
 }

 func (kv KV) Kind() string {
-	if s, ok := kv["general.type"].(string); ok {
-		return s
-	}
-
-	return "unknown"
+	return kv.String("general.type", "unknown")
 }

 func (kv KV) ParameterCount() uint64 {
-	return kv.u64("general.parameter_count")
+	return keyValue[uint64](kv, "general.parameter_count")
 }

 func (kv KV) FileType() fileType {
-	if u64 := kv.u64("general.file_type"); u64 > 0 {
-		return fileType(uint32(u64))
+	if t := kv.Uint("general.file_type"); t > 0 {
+		return fileType(t)
 	}

 	return fileTypeUnknown
 }

 func (kv KV) BlockCount() uint64 {
-	return kv.u64(fmt.Sprintf("%s.block_count", kv.Architecture()))
+	return uint64(kv.Uint("block_count"))
+}
+
+func (kv KV) EmbeddingLength() uint64 {
+	return uint64(kv.Uint("embedding_length"))
 }

 func (kv KV) HeadCount() uint64 {
-	return kv.u64(fmt.Sprintf("%s.attention.head_count", kv.Architecture()))
+	return uint64(kv.Uint("attention.head_count"))
 }

 func (kv KV) HeadCountKV() uint64 {
-	if headCountKV := kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture())); headCountKV > 0 {
-		return headCountKV
-	}
-
-	return 1
+	return uint64(kv.Uint("attention.head_count_kv", 1))
 }

 func (kv KV) EmbeddingHeadCount() uint64 {
 	if heads := kv.HeadCount(); heads > 0 {
-		return kv.EmbeddingLength() / kv.HeadCount()
+		return kv.EmbeddingLength() / heads
 	}

 	return 0
 }

 func (kv KV) EmbeddingHeadCountK() uint64 {
-	if k := kv.u64(fmt.Sprintf("%s.attention.key_length", kv.Architecture())); k > 0 {
-		return k
-	}
-
-	return kv.EmbeddingHeadCount()
+	return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCount())))
 }

 func (kv KV) EmbeddingHeadCountV() uint64 {
-	if v := kv.u64(fmt.Sprintf("%s.attention.value_length", kv.Architecture())); v > 0 {
-		return v
-	}
-
-	return kv.EmbeddingHeadCount()
+	return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCount())))
 }

 func (kv KV) GQA() uint64 {
 	return kv.HeadCount() / kv.HeadCountKV()
 }

-func (kv KV) EmbeddingLength() uint64 {
-	return kv.u64(fmt.Sprintf("%s.embedding_length", kv.Architecture()))
-}
-
 func (kv KV) ContextLength() uint64 {
-	return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture()))
+	return uint64(kv.Uint("context_length"))
 }

 func (kv KV) ChatTemplate() string {
-	s, _ := kv["tokenizer.chat_template"].(string)
+	return kv.String("tokenizer.chat_template")
 }

+func (kv KV) String(key string, defaultValue ...string) string {
+	return keyValue(kv, key, append(defaultValue, "")...)
+}
+
+func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
+	return keyValue(kv, key, append(defaultValue, 0)...)
+}
+
+func (kv KV) Float(key string, defaultValue ...float32) float32 {
+	return keyValue(kv, key, append(defaultValue, 0)...)
+}
+
+func (kv KV) Strings(key string, defaultValue ...[]string) []string {
+	r := keyValue(kv, key, &array{})
+	s := make([]string, r.size)
+	for i := range r.size {
+		s[i] = r.values[i].(string)
+	}
+
+	return s
+}
+
+func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
+	r := keyValue(kv, key, &array{})
+	s := make([]uint32, r.size)
+	for i := range r.size {
+		s[i] = uint32(r.values[i].(int32))
+	}
+
+	return s
+}
+
+func keyValue[T string | uint32 | uint64 | float32 | *array](kv KV, key string, defaultValue ...T) T {
+	if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
+		key = kv.Architecture() + "." + key
+	}
+
+	if val, ok := kv[key]; ok {
+		return val.(T)
+	}
+
+	slog.Warn("key not found", "key", key, "default", defaultValue[0])
+	return defaultValue[0]
+}
+
 type Tensors struct {
-	Items  []*Tensor
+	items  []*Tensor
 	Offset uint64
-
-	layers     map[string]Layer
-	layersOnce sync.Once
 }

-func (ts *Tensors) Layers() map[string]Layer {
-	ts.layersOnce.Do(func() {
-		ts.layers = make(map[string]Layer)
-		for _, t := range ts.Items {
-			parts := strings.Split(t.Name, ".")
-			if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 {
-				if len(parts) > index+2 {
-					// blk and mm should have a number after them, join it
-					parts = append(
-						[]string{strings.Join(parts[:index+2], ".")},
-						parts[index+2:]...)
-				}
-			}
-
-			if _, ok := ts.layers[parts[0]]; !ok {
-				ts.layers[parts[0]] = make(Layer)
-			}
-
-			ts.layers[parts[0]][strings.Join(parts[1:], ".")] = t
-		}
-	})
-
-	return ts.layers
+func (s Tensors) Items(prefix ...string) []*Tensor {
+	if len(prefix) == 0 {
+		return s.items
+	}
+
+	var items []*Tensor
+	for _, t := range s.items {
+		if strings.HasPrefix(t.Name, prefix[0]) {
+			items = append(items, t)
+		}
+	}
+
+	return items
+}
+
+func (ts Tensors) GroupLayers() map[string]Layer {
+	layers := make(map[string]Layer)
+	for _, t := range ts.items {
+		parts := strings.Split(t.Name, ".")
+		if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 {
+			if len(parts) > index+2 {
+				// blk and mm should have a number after them, join it
+				parts = append(
+					[]string{strings.Join(parts[:index+2], ".")},
+					parts[index+2:]...)
+			}
+		}
+
+		if _, ok := layers[parts[0]]; !ok {
+			layers[parts[0]] = make(Layer)
+		}
+
+		layers[parts[0]][strings.Join(parts[1:], ".")] = t
+	}
+
+	return layers
 }

 type Layer map[string]*Tensor

-func (l Layer) size() (size uint64) {
+func (l Layer) Size() (size uint64) {
 	for _, t := range l {
 		size += t.Size()
 	}
@ -255,8 +276,6 @@ func (t Tensor) typeSize() uint64 {
 		return 8
 	case 29: // IQ1_M
 		return blockSize/8 + blockSize/16 + blockSize/32
-	case 30: // BF16
-		return 2
 	default:
 		return 0
 	}
@ -295,7 +314,7 @@ const (

 var ErrUnsupportedFormat = errors.New("unsupported model format")

-func DetectGGMLType(b []byte) string {
+func DetectContentType(b []byte) string {
 	switch binary.LittleEndian.Uint32(b[:4]) {
 	case FILE_MAGIC_GGML:
 		return "ggml"
@ -312,12 +331,12 @@ func DetectGGMLType(b []byte) string {
 	}
 }

-// DecodeGGML decodes a GGML model from the given reader.
+// Decode decodes a GGML model from the given reader.
 //
 // It collects array values for arrays with a size less than or equal to
 // maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
 // the maxArraySize is negative, all arrays are collected.
-func DecodeGGML(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
+func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
 	if maxArraySize == 0 {
 		maxArraySize = 1024
 	}
@ -331,10 +350,6 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {

 	var c container
 	switch magic {
-	case FILE_MAGIC_GGML, FILE_MAGIC_GGMF, FILE_MAGIC_GGJT:
-		return nil, 0, ErrUnsupportedFormat
-	case FILE_MAGIC_GGLA:
-		c = &containerGGLA{}
 	case FILE_MAGIC_GGUF_LE:
 		c = &containerGGUF{ByteOrder: binary.LittleEndian, maxArraySize: maxArraySize}
 	case FILE_MAGIC_GGUF_BE:
@ -360,22 +375,22 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
 	}, offset, nil
 }

-func (llm GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
-	embedding := llm.KV().EmbeddingLength()
-	heads := llm.KV().HeadCount()
-	headsKV := llm.KV().HeadCountKV()
-	vocab := uint64(llm.KV()["tokenizer.ggml.tokens"].(*array).size)
+func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
+	embedding := f.KV().EmbeddingLength()
+	heads := f.KV().HeadCount()
+	headsKV := f.KV().HeadCountKV()
+	vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array).size)

-	embeddingHeads := llm.KV().EmbeddingHeadCount()
-	embeddingHeadsK := llm.KV().EmbeddingHeadCountK()
-	embeddingHeadsV := llm.KV().EmbeddingHeadCountV()
+	embeddingHeads := f.KV().EmbeddingHeadCount()
+	embeddingHeadsK := f.KV().EmbeddingHeadCountK()
+	embeddingHeadsV := f.KV().EmbeddingHeadCountV()

-	layers := llm.Tensors().Layers()
+	layers := f.Tensors().GroupLayers()

 	bytesPerElement := kvCacheBytesPerElement(kvCacheType)
-	kv = uint64(float64(context*llm.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
+	kv = uint64(float64(context*f.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)

-	switch llm.KV().Architecture() {
+	switch f.KV().Architecture() {
 	case "llama":
 		fullOffload = max(
 			4*batch*(1+4*embedding+context*(1+heads)),
@ -390,7 +405,7 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {

 	if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
 		// mixtral 8x22b
-		ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
+		ff := uint64(f.KV()["llama.feed_forward_length"].(uint32))
 		partialOffload = max(
 			3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
 			4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
@ -407,11 +422,11 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
 	case "mllama":
 		var visionTokens, tiles uint64 = 1601, 4

-		if crossAttentionLayers, ok := llm.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
+		if crossAttentionLayers, ok := f.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
 			kv = headsKV *
 				(embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
 				(2* // sizeof(float16)
-					(llm.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
+					(f.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
 					context +
 					4* // sizeof(float32)
 						uint64(crossAttentionLayers.size)* // num cross attention layers
@ -426,7 +441,7 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
 		)

 		var ropeFreqsCount uint64
-		if ropeFreqs, ok := llm.Tensors().Layers()["rope_freqs"]; ok {
+		if ropeFreqs, ok := f.Tensors().GroupLayers()["rope_freqs"]; ok {
 			if ropeFreqsWeights, ok := ropeFreqs["weights"]; ok {
 				ropeFreqsCount = ropeFreqsWeights.parameters()
 			}
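As a sanity check on the KV size formula above (context × block_count × (key_length + value_length) × head_count_kv × bytes per element), here is the arithmetic for a hypothetical 32-layer llama-style model. All numbers below are assumptions for illustration, not taken from any real model file.

```go
package main

import "fmt"

func main() {
	// Hypothetical numbers, not taken from any real model file.
	context := uint64(2048) // tokens of context
	blocks := uint64(32)    // block_count
	headsKV := uint64(8)    // attention.head_count_kv
	headDimK := uint64(128) // attention.key_length
	headDimV := uint64(128) // attention.value_length
	bytesPerElement := 2.0  // f16 cache entries

	kv := uint64(float64(context*blocks*(headDimK+headDimV)*headsKV) * bytesPerElement)
	fmt.Println(kv) // 268435456 bytes, i.e. a 256 MiB KV cache
}
```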
@ -530,21 +545,20 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
 }

 // SupportsKVCacheType checks if the requested cache type is supported
-func (ggml GGML) SupportsKVCacheType(cacheType string) bool {
-	validKVCacheTypes := []string{"f16", "q8_0", "q4_0"}
-	return slices.Contains(validKVCacheTypes, cacheType)
+func (f GGML) SupportsKVCacheType(cacheType string) bool {
+	return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
 }

 // SupportsFlashAttention checks if the model supports flash attention
-func (ggml GGML) SupportsFlashAttention() bool {
-	_, isEmbedding := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]
+func (f GGML) SupportsFlashAttention() bool {
+	_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]
 	if isEmbedding {
 		return false
 	}

 	// Check head counts match and are non-zero
-	headCountK := ggml.KV().EmbeddingHeadCountK()
-	headCountV := ggml.KV().EmbeddingHeadCountV()
+	headCountK := f.KV().EmbeddingHeadCountK()
+	headCountV := f.KV().EmbeddingHeadCountV()
 	return headCountK != 0 && headCountV != 0 && headCountK == headCountV
 }
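The new `keyValue` helper makes the key-resolution rule explicit: any key outside the `general.` and `tokenizer.` namespaces is prefixed with the model architecture before lookup, and a missing key logs a warning and returns the supplied default. A sketch of the resulting accessor behavior, assuming it sits inside the `fs/ggml` package shown above:

```go
package ggml

import "fmt"

// Sketch of the reworked accessors: keys outside "general." and
// "tokenizer." are namespaced by the architecture before lookup.
func ExampleKVAccessors() {
	kv := KV{
		"general.architecture": "llama",
		"llama.block_count":    uint32(32),
	}

	fmt.Println(kv.Architecture())      // llama
	fmt.Println(kv.Uint("block_count")) // 32 — expanded to "llama.block_count"
	fmt.Println(kv.Uint("missing", 7))  // 7 — not found, default returned (a warning is logged)
}
```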
159 fs/ggml/ggml_test.go Normal file
@ -0,0 +1,159 @@
package ggml

import (
    "maps"
    "slices"
    "strings"
    "testing"

    "github.com/google/go-cmp/cmp"
)

func TestTensorLayers(t *testing.T) {
    tensors := make(map[string]*Tensor)
    for _, name := range []string{
        "token_embd.weight",
        "blk.0.attn_k.weight",
        "blk.0.attn_output.weight",
        "blk.0.attn_q.weight",
        "blk.0.attn_v.weight",
        "blk.0.attn_norm.weight",
        "blk.0.ffn_down.weight",
        "blk.0.ffn_gate.weight",
        "blk.0.ffn_up.weight",
        "blk.0.ffn_norm.weight",
        "output_norm.weight",
        "mm.0.bias",
        "mm.0.weight",
        "v.blk.0.attn_k.weight",
        "v.blk.0.attn_output.weight",
        "v.blk.0.attn_q.weight",
        "v.blk.0.attn_v.weight",
        "v.blk.0.attn_norm.weight",
        "v.blk.0.ffn_down.weight",
        "v.blk.0.ffn_gate.weight",
        "v.blk.0.ffn_up.weight",
        "v.blk.0.ffn_norm.weight",
        "v.patch_embd.weight",
        "v.position_embd.gate",
        "v.position_embd.weight",
    } {
        tensors[name] = &Tensor{Name: name}
    }

    cases := []struct {
        name  string
        items []*Tensor
        want  map[string]Layer
    }{
        {
            name: "text",
            items: slices.Collect(func(yield func(*Tensor) bool) {
                for k, v := range tensors {
                    if !strings.HasPrefix(k, "mm.") && !strings.HasPrefix(k, "v.") {
                        if !yield(v) {
                            return
                        }
                    }
                }
            }),
            want: map[string]Layer{
                "blk.0": {
                    "attn_k.weight":      tensors["blk.0.attn_k.weight"],
                    "attn_q.weight":      tensors["blk.0.attn_q.weight"],
                    "attn_v.weight":      tensors["blk.0.attn_v.weight"],
                    "attn_output.weight": tensors["blk.0.attn_output.weight"],
                    "attn_norm.weight":   tensors["blk.0.attn_norm.weight"],
                    "ffn_down.weight":    tensors["blk.0.ffn_down.weight"],
                    "ffn_gate.weight":    tensors["blk.0.ffn_gate.weight"],
                    "ffn_up.weight":      tensors["blk.0.ffn_up.weight"],
                    "ffn_norm.weight":    tensors["blk.0.ffn_norm.weight"],
                },
                "token_embd":  {"weight": tensors["token_embd.weight"]},
                "output_norm": {"weight": tensors["output_norm.weight"]},
            },
        },
        {
            name: "vision",
            items: slices.Collect(func(yield func(*Tensor) bool) {
                for k, v := range tensors {
                    if strings.HasPrefix(k, "mm.") || strings.HasPrefix(k, "v.") {
                        if !yield(v) {
                            return
                        }
                    }
                }
            }),
            want: map[string]Layer{
                "mm.0": {
                    "bias":   tensors["mm.0.bias"],
                    "weight": tensors["mm.0.weight"],
                },
                "v.blk.0": {
                    "attn_k.weight":      tensors["v.blk.0.attn_k.weight"],
                    "attn_q.weight":      tensors["v.blk.0.attn_q.weight"],
                    "attn_v.weight":      tensors["v.blk.0.attn_v.weight"],
                    "attn_output.weight": tensors["v.blk.0.attn_output.weight"],
                    "attn_norm.weight":   tensors["v.blk.0.attn_norm.weight"],
                    "ffn_down.weight":    tensors["v.blk.0.ffn_down.weight"],
                    "ffn_gate.weight":    tensors["v.blk.0.ffn_gate.weight"],
                    "ffn_up.weight":      tensors["v.blk.0.ffn_up.weight"],
                    "ffn_norm.weight":    tensors["v.blk.0.ffn_norm.weight"],
                },
                "v": {
                    "patch_embd.weight":    tensors["v.patch_embd.weight"],
                    "position_embd.gate":   tensors["v.position_embd.gate"],
                    "position_embd.weight": tensors["v.position_embd.weight"],
                },
            },
        },
        {
            name:  "vision and text",
            items: slices.Collect(maps.Values(tensors)),
            want: map[string]Layer{
                "blk.0": {
                    "attn_k.weight":      tensors["blk.0.attn_k.weight"],
                    "attn_q.weight":      tensors["blk.0.attn_q.weight"],
                    "attn_v.weight":      tensors["blk.0.attn_v.weight"],
                    "attn_output.weight": tensors["blk.0.attn_output.weight"],
                    "attn_norm.weight":   tensors["blk.0.attn_norm.weight"],
                    "ffn_down.weight":    tensors["blk.0.ffn_down.weight"],
                    "ffn_gate.weight":    tensors["blk.0.ffn_gate.weight"],
                    "ffn_up.weight":      tensors["blk.0.ffn_up.weight"],
                    "ffn_norm.weight":    tensors["blk.0.ffn_norm.weight"],
                },
                "token_embd":  {"weight": tensors["token_embd.weight"]},
                "output_norm": {"weight": tensors["output_norm.weight"]},
                "mm.0": {
                    "bias":   tensors["mm.0.bias"],
                    "weight": tensors["mm.0.weight"],
                },
                "v.blk.0": {
                    "attn_k.weight":      tensors["v.blk.0.attn_k.weight"],
                    "attn_q.weight":      tensors["v.blk.0.attn_q.weight"],
                    "attn_v.weight":      tensors["v.blk.0.attn_v.weight"],
                    "attn_output.weight": tensors["v.blk.0.attn_output.weight"],
                    "attn_norm.weight":   tensors["v.blk.0.attn_norm.weight"],
                    "ffn_down.weight":    tensors["v.blk.0.ffn_down.weight"],
                    "ffn_gate.weight":    tensors["v.blk.0.ffn_gate.weight"],
                    "ffn_up.weight":      tensors["v.blk.0.ffn_up.weight"],
                    "ffn_norm.weight":    tensors["v.blk.0.ffn_norm.weight"],
                },
                "v": {
                    "patch_embd.weight":    tensors["v.patch_embd.weight"],
                    "position_embd.gate":   tensors["v.position_embd.gate"],
                    "position_embd.weight": tensors["v.position_embd.weight"],
                },
            },
        },
    }

    for _, tt := range cases {
        t.Run(tt.name, func(t *testing.T) {
            got := Tensors{items: tt.items}.GroupLayers()
            if diff := cmp.Diff(got, tt.want); diff != "" {
                t.Errorf("unexpected layers (-got +want):\n%s", diff)
            }
        })
    }
}
llm/gguf.go → fs/ggml/gguf.go
@ -1,4 +1,4 @@
-package llm
+package ggml

 import (
 	"bytes"
@ -8,10 +8,9 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
+	"maps"
 	"slices"
 	"strings"
-
-	"golang.org/x/exp/maps"
 )

 type containerGGUF struct {
@ -110,9 +109,9 @@ func (llm *gguf) KV() KV {
 	return llm.kv
 }

-func (llm *gguf) Tensors() *Tensors {
-	return &Tensors{
-		Items:  llm.tensors,
+func (llm *gguf) Tensors() Tensors {
+	return Tensors{
+		items:  llm.tensors,
 		Offset: llm.tensorOffset,
 	}
 }
@ -523,7 +522,7 @@ func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
 		return err
 	}

-	keys := maps.Keys(kv)
+	keys := slices.Collect(maps.Keys(kv))
 	slices.Sort(keys)

 	for _, key := range keys {
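The switch from `golang.org/x/exp/maps` to the standard library `maps` package changes `maps.Keys` from returning a slice to returning an iterator (`iter.Seq`, Go 1.23+), hence the new `slices.Collect` wrapper. In miniature:

```go
package main

import (
	"fmt"
	"maps"
	"slices"
)

func main() {
	kv := map[string]int{"b": 2, "a": 1}

	// x/exp/maps.Keys returned []string; stdlib maps.Keys returns an
	// iter.Seq[string], so it must be collected before sorting.
	keys := slices.Collect(maps.Keys(kv))
	slices.Sort(keys)
	fmt.Println(keys) // [a b]
}
```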
@ -1,4 +1,4 @@
-package llm
+package ggml

 import "fmt"

@ -98,10 +98,10 @@ func ParseFileType(s string) (fileType, error) {
 		return fileTypeIQ3_M, nil
 	case "IQ2_S":
 		return fileTypeIQ2_S, nil
-	case "IQ4_XS":
-		return fileTypeIQ4_XS, nil
 	case "IQ2_M":
 		return fileTypeIQ2_M, nil
+	case "IQ4_XS":
+		return fileTypeIQ4_XS, nil
 	case "IQ1_M":
 		return fileTypeIQ1_M, nil
 	case "BF16":
149 llm/ggla.go
@ -1,149 +0,0 @@ (file deleted)
package llm

import (
    "encoding/binary"
    "errors"
    "io"
    "slices"
)

type containerGGLA struct {
    version uint32
}

func (c *containerGGLA) Name() string {
    return "ggla"
}

func (c *containerGGLA) Decode(rs io.ReadSeeker) (model, error) {
    if err := binary.Read(rs, binary.LittleEndian, &c.version); err != nil {
        return nil, err
    }

    switch c.version {
    case 1:
    default:
        return nil, errors.New("invalid version")
    }

    model := newGGLA(c)
    err := model.decode(rs)
    return model, err
}

type ggla struct {
    *containerGGLA

    kv      KV
    tensors []*Tensor

    tensorOffset uint64
}

func newGGLA(container *containerGGLA) *ggla {
    return &ggla{
        containerGGLA: container,
        kv:            make(KV),
    }
}

func (llm *ggla) KV() KV {
    return llm.kv
}

func (llm *ggla) Tensors() *Tensors {
    return &Tensors{
        Items:  llm.tensors,
        Offset: llm.tensorOffset,
    }
}

func (llm *ggla) decode(rs io.ReadSeeker) (retErr error) {
    var r uint32
    if err := binary.Read(rs, binary.LittleEndian, &r); err != nil {
        return err
    }
    llm.kv["r"] = r

    var alpha uint32
    if err := binary.Read(rs, binary.LittleEndian, &alpha); err != nil {
        return err
    }
    llm.kv["alpha"] = alpha

    offset, err := rs.Seek(0, io.SeekCurrent)
    if err != nil {
        return err
    }

    llm.tensorOffset = uint64(offset)

    for {
        var dims uint32
        if err := binary.Read(rs, binary.LittleEndian, &dims); err != nil {
            if errors.Is(err, io.EOF) {
                return nil
            }
            return err
        }

        defer func() {
            if errors.Is(retErr, io.EOF) {
                retErr = io.ErrUnexpectedEOF
            }
        }()

        var namesize uint32
        if err := binary.Read(rs, binary.LittleEndian, &namesize); err != nil {
            return err
        }

        var t Tensor
        if err := binary.Read(rs, binary.LittleEndian, &t.Kind); err != nil {
            return err
        }

        t.Shape = make([]uint64, dims)
        for i := 0; uint32(i) < dims; i++ {
            var shape32 uint32
            if err := binary.Read(rs, binary.LittleEndian, &shape32); err != nil {
                return err
            }

            t.Shape[i] = uint64(shape32)
        }

        // ggla tensor shape is reversed
        // ref: https://github.com/ggerganov/llama.cpp/blob/29ae62d2ae163e2b68aa0ad3bf2ab4636de0c957/convert-lora-to-ggml.py#L44
        slices.Reverse(t.Shape)

        name := make([]byte, namesize)
        if err := binary.Read(rs, binary.LittleEndian, &name); err != nil {
            return err
        }

        t.Name = string(name)

        offset, err := rs.Seek(0, io.SeekCurrent)
        if err != nil {
            return err
        }

        if _, err := rs.Seek((offset+31)&-32-offset, io.SeekCurrent); err != nil {
            return err
        }

        offset, err = rs.Seek(0, io.SeekCurrent)
        if err != nil {
            return err
        }

        t.Offset = uint64(offset)

        if _, err := rs.Seek(int64(t.Size()), io.SeekCurrent); err != nil {
            return err
        }

        llm.tensors = append(llm.tensors, &t)
    }
}
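One detail in the removed decoder worth a note: ggla tensor data is 32-byte aligned, and `(offset+31)&-32` rounds the current offset up to the next multiple of 32 before seeking past the padding. In isolation:

```go
package main

import "fmt"

func main() {
	offset := int64(100)
	// &-32 clears the low five bits, so this rounds up to the next
	// multiple of 32; it is equivalent to (offset+31) &^ 31.
	aligned := (offset + 31) & -32
	fmt.Println(aligned, aligned-offset) // 128 28: seek past 28 pad bytes
}
```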
@ -1 +0,0 @@ (file deleted)
package llm
llm/memory.go
@ -11,18 +11,19 @@ import (
 	"github.com/ollama/ollama/discover"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
+	"github.com/ollama/ollama/fs/ggml"
 )

 // This algorithm looks for a complete fit to determine if we need to unload other models
-func PredictServerFit(allGpus discover.GpuInfoList, ggml *GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
+func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
 	// Split up the GPUs by type and try them
 	var estimatedVRAM uint64
 	for _, gpus := range allGpus.ByLibrary() {
 		var layerCount int
-		estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
+		estimate := EstimateGPULayers(gpus, f, projectors, opts)
 		layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize
 		if opts.NumGPU < 0 {
-			if layerCount > 0 && layerCount >= int(ggml.KV().BlockCount()+1) {
+			if layerCount > 0 && layerCount >= int(f.KV().BlockCount()+1) {
 				return true, estimatedVRAM
 			}
 		} else {
@ -70,7 +71,7 @@ type MemoryEstimate struct {

 // Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
 // The GPUs provided must all be the same Library
-func EstimateGPULayers(gpus []discover.GpuInfo, ggml *GGML, projectors []string, opts api.Options) MemoryEstimate {
+func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options) MemoryEstimate {
 	// Graph size for a partial offload, applies to all GPUs
 	var graphPartialOffload uint64

@ -115,33 +116,31 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options) MemoryEstimate {
 		opts.NumCtx = max(opts.NumCtx, 2048)
 	}

-	layers := ggml.Tensors().Layers()
+	layers := f.Tensors().GroupLayers()
 	// add one layer worth of memory as a buffer
 	if blk0, ok := layers["blk.0"]; ok {
-		layerSize = blk0.size()
+		layerSize = blk0.Size()
 	} else {
 		slog.Warn("model missing blk.0 layer size")
 	}

-	fa := envconfig.FlashAttention() &&
-		discover.GetGPUInfo().FlashAttentionSupported() &&
-		ggml.SupportsFlashAttention()
-
 	var kvct string
-	if fa {
+	if envconfig.FlashAttention() &&
+		discover.GetGPUInfo().FlashAttentionSupported() &&
+		f.SupportsFlashAttention() {
 		requested := strings.ToLower(envconfig.KvCacheType())
-		if requested != "" && ggml.SupportsKVCacheType(requested) {
+		if requested != "" && f.SupportsKVCacheType(requested) {
 			kvct = requested
 		}
 	}

-	kv, graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), kvct)
+	kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), kvct)

 	// KV is proportional to the number of layers
-	layerSize += kv / ggml.KV().BlockCount()
+	layerSize += kv / f.KV().BlockCount()

 	if graphPartialOffload == 0 {
-		graphPartialOffload = ggml.KV().GQA() * kv / 6
+		graphPartialOffload = f.KV().GQA() * kv / 6
 	}
 	if graphFullOffload == 0 {
 		graphFullOffload = graphPartialOffload
@ -156,12 +155,12 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options) MemoryEstimate {
 	}

 	if layer, ok := layers["output_norm"]; ok {
-		memoryLayerOutput += layer.size()
+		memoryLayerOutput += layer.Size()
 	}
 	if layer, ok := layers["output"]; ok {
-		memoryLayerOutput += layer.size()
+		memoryLayerOutput += layer.Size()
 	} else if layer, ok := layers["token_embd"]; ok {
-		memoryLayerOutput += layer.size()
+		memoryLayerOutput += layer.Size()
 	}

 	// Output layer handled at the end if we have space
@ -211,11 +210,11 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options) MemoryEstimate {
 	}

 	// For all the layers, find where they can fit on the GPU(s)
-	for i := range int(ggml.KV().BlockCount()) {
+	for i := range int(f.KV().BlockCount()) {
 		// Some models have inconsistent layer sizes
 		if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
-			layerSize = blk.size()
-			layerSize += kv / ggml.KV().BlockCount()
+			layerSize = blk.Size()
+			layerSize += kv / f.KV().BlockCount()
 		}
 		memoryWeights += layerSize

@ -238,10 +237,10 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options) MemoryEstimate {
 			}
 		}
 	}
-	if layerCount >= int(ggml.KV().BlockCount()) {
+	if layerCount >= int(f.KV().BlockCount()) {
 		fullyLoaded = true
 	} else {
-		for i := layerCount; i < int(ggml.KV().BlockCount()); i++ {
+		for i := layerCount; i < int(f.KV().BlockCount()); i++ {
 			overflow += layerSize
 		}
 	}
@ -259,7 +258,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options) MemoryEstimate {
 		}
 	}

-	if layerCount < int(ggml.KV().BlockCount())+1 {
+	if layerCount < int(f.KV().BlockCount())+1 {
 		fullyLoaded = false
 		overflow += memoryLayerOutput
 	}
@ -311,7 +310,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options) MemoryEstimate {

 		inferenceLibrary:    gpus[0].Library,
 		layersRequested:     opts.NumGPU,
-		layersModel:         int(ggml.KV().BlockCount()) + 1,
+		layersModel:         int(f.KV().BlockCount()) + 1,
 		availableList:       availableList,
 		kv:                  kv,
 		allocationsList:     allocationsList,
@ -339,22 +338,9 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options) MemoryEstimate {
 	return estimate
 }

-func (m MemoryEstimate) log() {
-	overhead := envconfig.GpuOverhead()
-
-	log := slog.With()
-	if m.projectorWeights > 0 {
-		log = log.With(
-			slog.Group(
-				"projector",
-				"weights", format.HumanBytes2(m.projectorWeights),
-				"graph", format.HumanBytes2(m.projectorGraph),
-			),
-		)
-	}
-
-	log.Info(
-		"offload to "+m.inferenceLibrary,
+func (m MemoryEstimate) LogValue() slog.Value {
+	attrs := []slog.Attr{
+		slog.String("library", m.inferenceLibrary),
 		slog.Group(
 			"layers",
 			// requested number of layers to offload
@ -370,7 +356,7 @@ func (m MemoryEstimate) LogValue() slog.Value {
 			"memory",
 			// memory available by GPU for offloading
 			"available", m.availableList,
-			"gpu_overhead", format.HumanBytes2(overhead),
+			"gpu_overhead", format.HumanBytes2(envconfig.GpuOverhead()),
 			slog.Group(
 				"required",
 				// memory required for full offloading
@ -399,7 +385,17 @@ func (m MemoryEstimate) LogValue() slog.Value {
 			"partial", format.HumanBytes2(m.graphPartialOffload),
 		),
 	),
-	)
+	}
+
+	if m.projectorWeights > 0 {
+		attrs = append(attrs, slog.Group(
+			"projector",
+			"weights", format.HumanBytes2(m.projectorWeights),
+			"graph", format.HumanBytes2(m.projectorGraph),
+		))
+	}
+
+	return slog.GroupValue(attrs...)
 }

 func projectorMemoryRequirements(filename string) (weights, graphSize uint64) {
@ -409,13 +405,13 @@ func projectorMemoryRequirements(filename string) (weights, graphSize uint64) {
 	}
 	defer file.Close()

-	ggml, _, err := DecodeGGML(file, 0)
+	ggml, _, err := ggml.Decode(file, 0)
 	if err != nil {
 		return 0, 0
 	}

-	for _, layer := range ggml.Tensors().Layers() {
-		weights += layer.size()
+	for _, layer := range ggml.Tensors().GroupLayers() {
+		weights += layer.Size()
 	}

 	switch arch := ggml.KV().Architecture(); arch {
@ -435,7 +431,7 @@ func projectorMemoryRequirements(filename string) (weights, graphSize uint64) {
 	headCount := kv("attention.head_count")

 	numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
-	if _, ok := ggml.Tensors().Layers()["v"]["class_embd"]; ok {
+	if _, ok := ggml.Tensors().GroupLayers()["v"]["class_embd"]; ok {
 		numPatches++
 	}
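Replacing the bespoke `log()` method with `LogValue` makes `MemoryEstimate` a `slog.LogValuer`, so the new call site in server.go can hand the estimate straight to the logger and the attributes are built lazily. The pattern in miniature:

```go
package main

import "log/slog"

type estimate struct{ layers int }

// LogValue is called lazily by slog when the record is actually
// emitted, so formatting cost is only paid for enabled log levels.
func (e estimate) LogValue() slog.Value {
	return slog.GroupValue(slog.Int("layers", e.layers))
}

func main() {
	slog.Info("offload", "estimate", estimate{layers: 33})
}
```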
llm/memory_test.go
@ -11,6 +11,7 @@ import (

 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/discover"
+	"github.com/ollama/ollama/fs/ggml"
 )

 func TestEstimateGPULayers(t *testing.T) {
@ -23,7 +24,7 @@ func TestEstimateGPULayers(t *testing.T) {
 	defer f.Close()
 	inputLayerCount := 5

-	tensors := []Tensor{
+	tensors := []ggml.Tensor{
 		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
 		{Name: "blk.1.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
 		{Name: "blk.2.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
@ -32,7 +33,7 @@ func TestEstimateGPULayers(t *testing.T) {
 		{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
 	}
 	assert.Len(t, tensors, inputLayerCount+1)
-	err = WriteGGUF(f, KV{
+	err = ggml.WriteGGUF(f, ggml.KV{
 		"general.architecture":   "llama",
 		"llama.context_length":   uint32(32),
 		"llama.embedding_length": uint32(4096),
@@ -28,6 +28,7 @@ import (
 	"github.com/ollama/ollama/discover"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
+	"github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/llama"
 )

@@ -71,7 +72,7 @@ type llmServer struct {
 // It collects array values for arrays with a size less than or equal to
 // maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
 // the maxArraySize is negative, all arrays are collected.
-func LoadModel(model string, maxArraySize int) (*GGML, error) {
+func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
 	if _, err := os.Stat(model); err != nil {
 		return nil, err
 	}
@@ -82,21 +83,17 @@ func LoadModel(model string, maxArraySize int) (*GGML, error) {
 	}
 	defer f.Close()

-	ggml, _, err := DecodeGGML(f, maxArraySize)
+	ggml, _, err := ggml.Decode(f, maxArraySize)
 	return ggml, err
 }

 // NewLlamaServer will run a server for the given GPUs
 // The gpu list must be a single family.
-func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
-	var systemTotalMemory uint64
-	var systemFreeMemory uint64
-	var systemSwapFreeMemory uint64
-
+func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
 	systemInfo := discover.GetSystemInfo()
-	systemTotalMemory = systemInfo.System.TotalMemory
-	systemFreeMemory = systemInfo.System.FreeMemory
-	systemSwapFreeMemory = systemInfo.System.FreeSwap
+	systemTotalMemory := systemInfo.System.TotalMemory
+	systemFreeMemory := systemInfo.System.FreeMemory
+	systemSwapFreeMemory := systemInfo.System.FreeSwap
 	slog.Info("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))

 	// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
@@ -104,7 +101,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
 		gpus = discover.GetCPUInfo()
 	}

-	estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
+	estimate := EstimateGPULayers(gpus, f, projectors, opts)
 	if len(gpus) > 1 || gpus[0].Library != "cpu" {
 		switch {
 		case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory:
@@ -130,7 +127,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
 		}
 	}

-	estimate.log()
+	slog.Info("offload", "", estimate)

 	params := []string{
 		"--model", model,
@@ -174,7 +171,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
 		fa = false
 	}

-	if fa && !ggml.SupportsFlashAttention() {
+	if fa && !f.SupportsFlashAttention() {
 		slog.Warn("flash attention enabled but not supported by model")
 		fa = false
 	}
@@ -187,7 +184,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter

 	// Flash Attention also supports kv cache quantization
 	// Enable if the requested and kv cache type is supported by the model
-	if kvct != "" && ggml.SupportsKVCacheType(kvct) {
+	if kvct != "" && f.SupportsKVCacheType(kvct) {
 		params = append(params, "--kv-cache-type", kvct)
 	} else {
 		slog.Warn("kv cache type not supported by model", "type", kvct)
@@ -200,7 +197,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
 	for _, g := range gpus {
 		if g.Library == "metal" &&
 			uint64(opts.NumGPU) > 0 &&
-			uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
+			uint64(opts.NumGPU) < f.KV().BlockCount()+1 {
 			opts.UseMMap = new(bool)
 			*opts.UseMMap = false
 		}
@@ -335,7 +332,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
 		estimate:    estimate,
 		numParallel: numParallel,
 		sem:         semaphore.NewWeighted(int64(numParallel)),
-		totalLayers: ggml.KV().BlockCount() + 1,
+		totalLayers: f.KV().BlockCount() + 1,
 		gpus:        gpus,
 		done:        make(chan error, 1),
 	}
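For orientation, a minimal caller-side sketch of the reworked LoadModel signature, assuming it is invoked from outside the llm package and that a GGUF file exists at the (hypothetical) path below; illustrative, not part of this change:

package main

import (
    "fmt"
    "log"

    "github.com/ollama/ollama/llm"
)

func main() {
    // maxArraySize 0 keeps the default of 1024, per the comment on LoadModel.
    f, err := llm.LoadModel("/path/to/model.gguf", 0) // hypothetical path
    if err != nil {
        log.Fatal(err)
    }

    fmt.Println(f.KV().Architecture(), f.KV().BlockCount())
}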
196
ml/backend.go
Normal file
@@ -0,0 +1,196 @@
package ml

import (
    "bytes"
    "encoding/binary"
    "fmt"
    "os"
    "strconv"
    "strings"
)

type Config interface {
    Architecture() string
    String(string, ...string) string
    Uint(string, ...uint32) uint32
    Float(string, ...float32) float32

    Strings(string, ...[]string) []string
    Uints(string, ...[]uint32) []uint32
}

type Backend interface {
    Config() Config
    Get(name string) Tensor
    NewContext() Context
}

var backends = make(map[string]func(*os.File) (Backend, error))

func RegisterBackend(name string, f func(*os.File) (Backend, error)) {
    if _, ok := backends[name]; ok {
        panic("backend: backend already registered")
    }

    backends[name] = f
}

func NewBackend(f *os.File) (Backend, error) {
    if backend, ok := backends["ggml"]; ok {
        return backend(f)
    }

    return nil, fmt.Errorf("unsupported backend")
}

type Context interface {
    Zeros(dtype DType, shape ...int) Tensor
    FromFloatSlice(s []float32, shape ...int) (Tensor, error)
    FromIntSlice(s []int32, shape ...int) (Tensor, error)

    Forward(Tensor)
    Compute(Tensor) Tensor
    Close() error
}

type Tensor interface {
    Dim(n int) int64
    Stride(n int) int64

    Shape() []int64
    DType() DType

    Bytes() []byte
    Floats() []float32

    Add(ctx Context, t2 Tensor) Tensor
    Mul(ctx Context, t2 Tensor) Tensor
    Mulmat(ctx Context, t2 Tensor) Tensor

    Softmax(ctx Context) Tensor
    LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
    RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
    Scale(ctx Context, s float64) Tensor

    Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
    RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor

    Tanh(ctx Context) Tensor
    GELU(ctx Context) Tensor
    SILU(ctx Context) Tensor

    Reshape(ctx Context, shape ...int64) Tensor
    View(ctx Context, offset int, shape ...int) Tensor
    Permute(ctx Context, shape ...int) Tensor
    Contiguous(ctx Context) Tensor

    Pad(ctx Context, shape ...int64) Tensor
    Unpad(ctx Context, shape ...int64) Tensor

    Stack(ctx Context, dim int, s ...Tensor) Tensor
    Concat(ctx Context, t2 Tensor, dim int) Tensor
    Rows(ctx Context, t2 Tensor) Tensor
    Copy(ctx Context, t2 Tensor) Tensor
}

type number interface {
    ~int | ~int8 | ~int16 | ~int32 | ~int64 |
        ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
        ~float32 | ~float64 |
        ~complex64 | ~complex128
}

func mul[T number](s ...T) T {
    p := T(1)
    for _, v := range s {
        p *= v
    }

    return p
}

type DumpOptions struct {
    // Items is the number of elements to print at the beginning and end of each dimension.
    Items int64

    // Precision is the number of decimal places to print. Applies to float32 and float64.
    Precision int
}

func Dump(t Tensor, opts ...DumpOptions) string {
    if len(opts) < 1 {
        opts = append(opts, DumpOptions{
            Items:     3,
            Precision: 4,
        })
    }

    switch t.DType() {
    case DTypeF32:
        return dump[[]float32](t, opts[0].Items, func(f float32) string {
            return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
        })
    case DTypeI32:
        return dump[[]int32](t, opts[0].Items, func(i int32) string {
            return strconv.FormatInt(int64(i), 10)
        })
    default:
        return "<unsupported>"
    }
}

func dump[S ~[]E, E number](t Tensor, items int64, fn func(E) string) string {
    bts := t.Bytes()
    if bts == nil {
        return "<nil>"
    }

    s := make(S, mul(t.Shape()...))
    if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
        panic(err)
    }

    shape := t.Shape()

    var sb strings.Builder
    var f func([]int64, int64)
    f = func(dims []int64, stride int64) {
        prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
        fmt.Fprint(&sb, "[")
        defer func() { fmt.Fprint(&sb, "]") }()
        for i := int64(0); i < dims[0]; i++ {
            if i >= items && i < dims[0]-items {
                fmt.Fprint(&sb, "..., ")
                // skip to next printable element
                skip := dims[0] - 2*items
                if len(dims) > 1 {
                    stride += mul(append(dims[1:], skip)...)
                    fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
                }
                i += skip - 1
            } else if len(dims) > 1 {
                f(dims[1:], stride)
                stride += mul(dims[1:]...)
                if i < dims[0]-1 {
                    fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
                }
            } else {
                fmt.Fprint(&sb, fn(s[stride+i]))
                if i < dims[0]-1 {
                    fmt.Fprint(&sb, ", ")
                }
            }
        }
    }
    f(shape, 0)

    return sb.String()
}

type DType int

const (
    DTypeF32 DType = iota
    DTypeI32
    DTypeOther
)
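The registry above is keyed by name, but NewBackend currently resolves only "ggml". A minimal sketch of what registering an alternate backend would look like; the package, type, and name here are hypothetical:

package mybackend

import (
    "os"

    "github.com/ollama/ollama/ml"
)

type backend struct{}

func (b *backend) Config() ml.Config         { panic("unimplemented") }
func (b *backend) Get(name string) ml.Tensor { panic("unimplemented") }
func (b *backend) NewContext() ml.Context    { panic("unimplemented") }

func init() {
    // Registering twice under the same name panics, mirroring RegisterBackend above.
    ml.RegisterBackend("mybackend", func(f *os.File) (ml.Backend, error) {
        return &backend{}, nil
    })
}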
5
ml/backend/backend.go
Normal file
@@ -0,0 +1,5 @@
package backend

import (
    _ "github.com/ollama/ollama/ml/backend/ggml"
)
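The blank import above is what wires the ggml backend's init() into a binary. A sketch of the end-to-end flow through the ml interfaces, assuming a valid GGUF file on disk; illustrative, not a snippet from this change:

package main

import (
    "log"
    "os"

    "github.com/ollama/ollama/ml"
    _ "github.com/ollama/ollama/ml/backend" // side effect: registers "ggml"
)

func main() {
    f, err := os.Open("model.gguf") // hypothetical model file
    if err != nil {
        log.Fatal(err)
    }
    defer f.Close()

    b, err := ml.NewBackend(f)
    if err != nil {
        log.Fatal(err)
    }

    ctx := b.NewContext()
    defer ctx.Close()

    t, err := ctx.FromFloatSlice([]float32{1, 2, 3, 4}, 2, 2)
    if err != nil {
        log.Fatal(err)
    }

    // Build the graph, run it, and read the result back.
    log.Println(ml.Dump(ctx.Compute(t.Softmax(ctx))))
}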
580
ml/backend/ggml/ggml.go
Normal file
@@ -0,0 +1,580 @@
package ggml

// #cgo CPPFLAGS: -I${SRCDIR}/ggml/include
// #include <stdlib.h>
// #include <stdint.h>
// #include "ggml.h"
// #include "ggml-cpu.h"
// #include "ggml-backend.h"
import "C"

import (
    "bytes"
    "encoding/binary"
    "fmt"
    "io"
    "log/slog"
    "os"
    "sync"
    "unsafe"

    "github.com/ollama/ollama/format"
    fs "github.com/ollama/ollama/fs/ggml"
    "github.com/ollama/ollama/ml"
    "golang.org/x/sync/errgroup"

    "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
)

type device struct {
    d *C.struct_ggml_backend_device
}

func (d device) LogValue() slog.Value {
    var free, total uint64
    C.ggml_backend_dev_memory(d.d, (*C.size_t)(&free), (*C.size_t)(&total))

    kind := "unknown"
    switch C.ggml_backend_dev_type(d.d) {
    case C.GGML_BACKEND_DEVICE_TYPE_CPU:
        kind = "cpu"
    case C.GGML_BACKEND_DEVICE_TYPE_GPU:
        kind = "gpu"
    case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
        kind = "accel"
    }

    return slog.GroupValue(
        slog.String("name", C.GoString(C.ggml_backend_dev_name(d.d))),
        slog.String("description", C.GoString(C.ggml_backend_dev_description(d.d))),
        slog.String("kind", kind),
        slog.String("free", format.HumanBytes2(free)),
        slog.String("total", format.HumanBytes2(total)),
    )
}

var devices = sync.OnceValue(func() []device {
    ggml.OnceLoad()

    s := make([]device, C.ggml_backend_dev_count())
    for i := range s {
        s[i] = device{C.ggml_backend_dev_get(C.size_t(i))}
    }

    return s
})

type Backend struct {
    meta       *fs.GGML
    cpus, gpus []Context
    tensors    map[string]*Context
}

func New(r *os.File) (ml.Backend, error) {
    meta, n, err := fs.Decode(r, -1)
    if err != nil {
        return nil, err
    }

    slog.Info(
        "",
        "architecture", meta.KV().Architecture(),
        "file_type", meta.KV().FileType(),
        "name", meta.KV().String("general.name"),
        "description", meta.KV().String("general.description"),
        "num_tensors", len(meta.Tensors().Items()),
        "num_key_values", len(meta.KV()),
    )

    var cpus, gpus []Context
    for _, d := range devices() {
        switch C.ggml_backend_dev_type(d.d) {
        case C.GGML_BACKEND_DEVICE_TYPE_CPU,
            C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
            slog.Info("cpu", "device", d)
            cpus = append(cpus, Context{
                ctx: C.ggml_init(C.struct_ggml_init_params{
                    mem_size: C.size_t(int(C.ggml_tensor_overhead()) * (len(meta.Tensors().Items()) + 1 + int(meta.KV().BlockCount())*2)),
                    no_alloc: true,
                }),
                backend: C.ggml_backend_dev_init(d.d, nil),
            })
        case C.GGML_BACKEND_DEVICE_TYPE_GPU:
            slog.Info("gpu", "device", d)
            gpus = append(gpus, Context{
                ctx: C.ggml_init(C.struct_ggml_init_params{
                    mem_size: C.size_t(int(C.ggml_tensor_overhead()) * (len(meta.Tensors().Items()) + 1 + int(meta.KV().BlockCount())*2)),
                    no_alloc: true,
                }),
                backend: C.ggml_backend_dev_init(d.d, nil),
            })
        }
    }

    ctxFunc := func(s []Context) (*Context, error) {
        for _, e := range s {
            return &e, nil
        }

        return nil, fmt.Errorf("no devices available")
    }

    tensors := make(map[*fs.Tensor]*Context, len(meta.Tensors().Items()))
    for _, t := range meta.Tensors().Items() {
        c, err := ctxFunc(append(gpus, cpus...))
        if err != nil {
            return nil, err
        }

        func() {
            tt := C.ggml_new_tensor(c.ctx, t.Kind, C.int(len(t.Shape)), (*C.int64_t)(unsafe.Pointer(&t.Shape[0])))

            cname := C.CString(t.Name)
            defer C.free(unsafe.Pointer(cname))
            C.ggml_set_name(tt, cname)

            tensors[t] = c
        }()
    }

    for _, b := range append(gpus, cpus...) {
        C.ggml_backend_alloc_ctx_tensors(b.ctx, b.backend)
    }

    sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))

    var g errgroup.Group
    for t, c := range tensors {
        g.Go(func() error {
            bts := make([]byte, t.Size())
            n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
            if err != nil {
                return err
            }

            if n != int(t.Size()) {
                return fmt.Errorf("expected %d bytes, got %d", t.Size(), n)
            }

            cname := C.CString(t.Name)
            defer C.free(unsafe.Pointer(cname))

            C.ggml_backend_tensor_set(C.ggml_get_tensor(c.ctx, cname), unsafe.Pointer(&bts[0]), 0, C.size_t(n))
            return nil
        })
    }

    if err := g.Wait(); err != nil {
        return nil, err
    }

    return &Backend{
        meta: meta,
        cpus: cpus,
        gpus: gpus,
    }, nil
}

func init() {
    ml.RegisterBackend("ggml", New)
}

func (b *Backend) Config() ml.Config {
    return b.meta.KV()
}

func (b *Backend) Get(name string) ml.Tensor {
    cname := C.CString(name)
    defer C.free(unsafe.Pointer(cname))

    for _, c := range append(b.gpus, b.cpus...) {
        if t := C.ggml_get_tensor(c.ctx, cname); t != nil {
            return &Tensor{t: t}
        }
    }

    return nil
}

func (b *Backend) NewContext() ml.Context {
    nodes := max(8192, len(b.meta.Tensors().Items())*5)
    bts := make([]byte, C.size_t(nodes)*C.ggml_tensor_overhead()+C.ggml_graph_overhead_custom(C.size_t(nodes), false))
    c := C.ggml_init(C.struct_ggml_init_params{
        mem_buffer: unsafe.Pointer(&bts[0]),
        mem_size:   C.size_t(len(bts)),
        no_alloc:   true,
    })

    backends := make([]*C.struct_ggml_backend, len(b.gpus)+len(b.cpus))
    bufts := make([]*C.struct_ggml_backend_buffer_type, len(b.gpus)+len(b.cpus))
    for i, c := range append(b.gpus, b.cpus...) {
        backends[i] = c.backend
        bufts[i] = C.ggml_backend_get_default_buffer_type(c.backend)
    }

    return &Context{
        ctx:     c,
        backend: backends[0],
        nodes:   nodes,
        sched: C.ggml_backend_sched_new(
            (*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
            (*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
            C.int(len(backends)),
            C.size_t(nodes),
            true,
        ),
    }
}

type Context struct {
    ctx     *C.struct_ggml_context
    backend *C.struct_ggml_backend

    sched *C.struct_ggml_backend_sched
    graph *C.struct_ggml_cgraph
    nodes int
}

func (c *Context) Forward(t ml.Tensor) {
    if c.graph == nil {
        c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.nodes), false)
    }

    C.ggml_build_forward_expand(c.graph, t.(*Tensor).t)
}

func (c *Context) Compute(t ml.Tensor) ml.Tensor {
    c.Forward(t)
    C.ggml_backend_sched_graph_compute_async(c.sched, c.graph)

    backend := C.ggml_backend_sched_get_tensor_backend(c.sched, t.(*Tensor).t)

    t.(*Tensor).data = make([]byte, C.ggml_nbytes(t.(*Tensor).t))
    C.ggml_backend_tensor_get_async(backend, t.(*Tensor).t, unsafe.Pointer(&t.(*Tensor).data[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
    return t
}

func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
    if len(shape) < 1 || len(shape) > 4 {
        panic("unsupported number of dimensions")
    }

    for _, dim := range shape {
        if dim < 1 {
            panic("invalid shape")
        }
    }

    var t *C.struct_ggml_tensor
    switch dtype {
    case ml.DTypeF32:
        t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), (*C.int64_t)(unsafe.Pointer(&shape[0])))
    case ml.DTypeI32:
        t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), (*C.int64_t)(unsafe.Pointer(&shape[0])))
    default:
        panic("unsupported dtype")
    }

    b := C.ggml_backend_alloc_buffer(c.backend, C.ggml_nbytes(t))
    C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
    C.ggml_set_zero(t)
    return &Tensor{t: t}
}

func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
    n := len(s)
    for _, v := range shape {
        n /= v
    }

    if n != 1 {
        return nil, fmt.Errorf("invalid shape %v for %d elements", shape, len(s))
    }

    t := C.ggml_new_tensor(ctx.ctx, dtype, C.int(len(shape)), (*C.int64_t)(unsafe.Pointer(&shape[0])))
    b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
    C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
    C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t))
    return &Tensor{t: t}, nil
}

func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
    return fromSlice(c, s, shape, C.GGML_TYPE_F32)
}

func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
    return fromSlice(c, s, shape, C.GGML_TYPE_I32)
}

func (c *Context) Close() error {
    C.ggml_backend_sched_free(c.sched)
    C.ggml_free(c.ctx)
    return nil
}

type Tensor struct {
    t    *C.struct_ggml_tensor
    data []byte
}

func (t *Tensor) LogValue() slog.Value {
    return slog.GroupValue(
        slog.String("name", C.GoString(C.ggml_get_name(t.t))),
        slog.String("type", C.GoString(C.ggml_type_name(t.t._type))),
        slog.Any("shape", t.Shape()),
    )
}

func (t *Tensor) Dim(n int) int64 {
    return int64(t.t.ne[n])
}

func (t *Tensor) Stride(n int) int64 {
    return int64(t.t.nb[n])
}

func (t *Tensor) Shape() []int64 {
    shape := make([]int64, C.ggml_n_dims(t.t))
    for i := range shape {
        shape[i] = t.Dim(i)
    }

    return shape
}

func (t *Tensor) Bytes() []byte {
    if bts := C.ggml_get_data(t.t); bts != nil {
        return C.GoBytes(bts, C.int(C.ggml_nbytes(t.t)))
    }

    return nil
}

func (t *Tensor) Floats() (f32s []float32) {
    if t.data != nil {
        f32s = make([]float32, C.ggml_nelements(t.t))
        _ = binary.Read(bytes.NewReader(t.data), binary.LittleEndian, f32s)
    }

    return
}

func (t *Tensor) DType() ml.DType {
    switch t.t._type {
    case C.GGML_TYPE_F32:
        return ml.DTypeF32
    case C.GGML_TYPE_I32:
        return ml.DTypeI32
    default:
        return ml.DTypeOther
    }
}

func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
    return &Tensor{
        t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
    }
}

func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
    if len(s) > 0 {
        return t.Concat(ctx, s[0].Stack(ctx, dim, s[1:]...), dim)
    }

    return t
}

func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
    return &Tensor{
        t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
    }
}

func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
    return &Tensor{
        t: C.ggml_cont(ctx.(*Context).ctx, t.t),
    }
}

func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
    return &Tensor{
        t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
    }
}

func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
    return &Tensor{
        t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
    }
}

func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
    tt := (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
    if b != nil {
        tt = tt.Add(ctx, b)
    }

    return tt
}

func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
    return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
}

func (t *Tensor) Pad(ctx ml.Context, shape ...int64) ml.Tensor {
    if len(shape) != 4 {
        panic("expected 4 dimensions")
    }

    return &Tensor{
        t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
    }
}

func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
    if len(shape) != 4 {
        panic("expected 4 dimensions")
    }

    return &Tensor{
        t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
    }
}

func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
    return &Tensor{
        t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
    }
}

func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
    return &Tensor{
        t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
    }
}

func (t *Tensor) Reshape(ctx ml.Context, shape ...int64) ml.Tensor {
    switch len(shape) {
    case 1:
        return &Tensor{
            t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
        }
    case 2:
        return &Tensor{
            t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
        }
    case 3:
        return &Tensor{
            t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
        }
    case 4:
        return &Tensor{
            t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
        }
    default:
        panic("unsupported number of dimensions")
    }
}

func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
    return &Tensor{
        t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
    }
}

func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
    return &Tensor{
        t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
    }
}

func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
    return &Tensor{
        t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
    }
}

func (t *Tensor) Unpad(ctx ml.Context, shape ...int64) ml.Tensor {
    if len(shape) != 4 {
        panic("expected 4 dimensions")
    }

    return &Tensor{
        t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
    }
}

func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
    switch len(shape) {
    case 1:
        return &Tensor{
            t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
        }
    case 3:
        return &Tensor{
            t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
                C.int64_t(shape[0]), C.int64_t(shape[2]),
                C.size_t(shape[1]),
                C.size_t(offset)),
        }
    case 5:
        return &Tensor{
            t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
                C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
                C.size_t(shape[1]), C.size_t(shape[3]),
                C.size_t(offset)),
        }
    case 7:
        return &Tensor{
            t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
                C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
                C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
                C.size_t(offset)),
        }
    default:
        panic("unsupported number of dimensions")
    }
}

const (
    ropeTypeNorm C.int = iota
)

func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
    if ropeFactors == nil {
        ropeFactors = &Tensor{}
    }

    return &Tensor{
        t: C.ggml_rope_ext(
            ctx.(*Context).ctx, t.t, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
            C.int(ropeDim),
            131072,       // YaRN n_ctx_train
            ropeTypeNorm, // ROPE_TYPE_NORM
            C.float(ropeBase),
            C.float(ropeScale),
            0.,  // YaRN ext_factor
            1.,  // YaRN attn_factor
            32., // YaRN beta_fast
            1.,  // YaRN beta_slow
        ),
    }
}

func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
    return &Tensor{
        t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
    }
}

func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
    return &Tensor{
        t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
    }
}

func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
    return &Tensor{
        t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
    }
}
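One convention in the file above worth calling out: View takes sizes and byte strides interleaved, which is why its switch accepts 1, 3, 5, or 7 trailing ints. A small sketch under that reading (the tensor and context are assumed to exist):

package sketch

import "github.com/ollama/ollama/ml"

// view2D returns a 3-by-2 view of t starting at byte offset 0, assuming t
// holds contiguous float32 data: arguments interleave as (ne0, nb1, ne1) —
// 3 elements per row, 3*4 bytes to step to the next row, 2 rows.
func view2D(ctx ml.Context, t ml.Tensor) ml.Tensor {
    return t.View(ctx, 0, 3, 3*4, 2)
}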
@@ -1,6 +0,0 @@
-//go:build debug
-
-package ggml
-
-// #cgo CPPFLAGS: -DOLLAMA_DEBUG
-import "C"
11
ml/nn/convolution.go
Normal file
@@ -0,0 +1,11 @@
package nn

import "github.com/ollama/ollama/ml"

type Conv2D struct {
    Weight ml.Tensor `gguf:"weight"`
}

func (m *Conv2D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
    return m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1)
}
11
ml/nn/embedding.go
Normal file
@@ -0,0 +1,11 @@
package nn

import "github.com/ollama/ollama/ml"

type Embedding struct {
    Weight ml.Tensor `gguf:"weight"`
}

func (m *Embedding) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
    return m.Weight.Rows(ctx, hiddenState)
}
17
ml/nn/linear.go
Normal file
@@ -0,0 +1,17 @@
package nn

import "github.com/ollama/ollama/ml"

type Linear struct {
    Weight ml.Tensor `gguf:"weight"`
    Bias   ml.Tensor `gguf:"bias"`
}

func (m *Linear) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
    t = m.Weight.Mulmat(ctx, t)
    if m.Bias != nil {
        t = t.Add(ctx, m.Bias)
    }

    return t
}
22
ml/nn/normalization.go
Normal file
@@ -0,0 +1,22 @@
package nn

import (
    "github.com/ollama/ollama/ml"
)

type LayerNorm struct {
    Weight ml.Tensor `gguf:"weight"`
    Bias   ml.Tensor `gguf:"bias"`
}

func (m *LayerNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
    return t.LayerNorm(ctx, m.Weight, m.Bias, eps)
}

type RMSNorm struct {
    Weight ml.Tensor `gguf:"weight"`
}

func (m *RMSNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
    return t.RMSNorm(ctx, m.Weight, eps)
}
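Taken together, these nn modules are thin wrappers over the Tensor ops. A minimal sketch of how they compose into a gated feed-forward block; this block is illustrative, not one of the layers defined in this change:

package sketch

import (
    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/ml/nn"
)

type feedForward struct {
    Norm *nn.RMSNorm `gguf:"ffn_norm"`
    Up   *nn.Linear  `gguf:"ffn_up"`
    Down *nn.Linear  `gguf:"ffn_down"`
}

// Forward normalizes, projects up, applies SiLU, and projects back down.
func (f *feedForward) Forward(ctx ml.Context, hiddenState ml.Tensor, eps float32) ml.Tensor {
    hiddenState = f.Norm.Forward(ctx, hiddenState, eps)
    hiddenState = f.Up.Forward(ctx, hiddenState).SILU(ctx)
    return f.Down.Forward(ctx, hiddenState)
}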
155
model/llama/model.go
Normal file
@@ -0,0 +1,155 @@
package llama

import (
    "math"

    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/ml/nn"
    "github.com/ollama/ollama/model"
)

type Options struct {
    RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`

    hiddenSize, numHeads, numKVHeads int64
    eps, ropeBase, ropeScale         float32
    ropeDim                          uint32
}

type Model struct {
    model.Base
    model.BytePairEncoding

    TokenEmbedding *nn.Embedding `gguf:"token_embd"`
    Layers         []Layer       `gguf:"blk"`
    OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
    Output         *nn.Linear    `gguf:"output,alt:token_embd"`

    *Options
}

func New(c ml.Config) (model.Model, error) {
    return &Model{
        BytePairEncoding: model.NewBytePairEncoding(
            c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
            &model.Vocabulary{
                Values: c.Strings("tokenizer.ggml.tokens"),
                Types:  c.Uints("tokenizer.ggml.token_type"),
                Merges: c.Strings("tokenizer.ggml.merges"),
                BOS:    c.Uint("tokenizer.ggml.bos_token_id"),
                EOS:    c.Uint("tokenizer.ggml.eos_token_id"),
            },
        ),
        Layers: make([]Layer, c.Uint("block_count")),
        Options: &Options{
            hiddenSize: int64(c.Uint("embedding_length")),
            numHeads:   int64(c.Uint("attention.head_count")),
            numKVHeads: int64(c.Uint("attention.head_count_kv")),
            eps:        c.Float("attention.layer_norm_rms_epsilon"),
            ropeBase:   c.Float("rope.freq_base"),
            ropeScale:  c.Float("rope.freq_scale", 1),
            ropeDim:    c.Uint("rope.dimension_count"),
        },
    }, nil
}

type SelfAttention struct {
    Query  *nn.Linear `gguf:"attn_q"`
    Key    *nn.Linear `gguf:"attn_k"`
    Value  *nn.Linear `gguf:"attn_v"`
    Output *nn.Linear `gguf:"attn_output"`
}

func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor {
    batchSize := hiddenState.Dim(1)
    headDim := opts.hiddenSize / opts.numHeads

    q := sa.Query.Forward(ctx, hiddenState)
    q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
    q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)

    k := sa.Key.Forward(ctx, hiddenState)
    k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
    k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)

    v := sa.Value.Forward(ctx, hiddenState)
    v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

    k, v = cache.Put(ctx, k, v, cache.Options)

    q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
    k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
    v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

    kq := k.Mulmat(ctx, q)
    kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
    kq = kq.Softmax(ctx)

    kqv := v.Mulmat(ctx, kq)
    kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
    kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)

    return sa.Output.Forward(ctx, kqv)
}

type MLP struct {
    Up   *nn.Linear `gguf:"ffn_up"`
    Down *nn.Linear `gguf:"ffn_down"`
    Gate *nn.Linear `gguf:"ffn_gate"`
}

func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
    hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
    return mlp.Down.Forward(ctx, hiddenState)
}

type Layer struct {
    AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
    SelfAttention *SelfAttention
    MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
    MLP           *MLP
}

func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor {
    residual := hiddenState

    hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
    hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
    hiddenState = hiddenState.Add(ctx, residual)
    residual = hiddenState

    hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
    hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
    return hiddenState.Add(ctx, residual)
}

func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
    inputs, err := ctx.FromIntSlice(opts.Inputs(), len(opts.Inputs()))
    if err != nil {
        return nil, err
    }

    positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions()))
    if err != nil {
        return nil, err
    }

    hiddenState := m.TokenEmbedding.Forward(ctx, inputs)

    for i, layer := range m.Layers {
        hiddenState = layer.Forward(ctx, hiddenState, positions, opts.Cache.Sub(i), m.Options)
    }

    hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
    hiddenState = m.Output.Forward(ctx, hiddenState)

    outputs, err := ctx.FromIntSlice([]int32{int32(len(opts.Positions())) - 1}, 1)
    if err != nil {
        return nil, err
    }

    return hiddenState.Rows(ctx, outputs), nil
}

func init() {
    model.Register("llama", New)
}
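To make the scaling step in SelfAttention.Forward concrete: a sketch with assumed llama-7B-style dimensions (hiddenSize 4096, 32 heads; illustrative values, not read from this diff):

package sketch

import "math"

// attentionScale computes the 1/sqrt(headDim) factor applied to kq above.
// With hiddenSize 4096 and numHeads 32, headDim is 128 and the scale ~0.0884.
func attentionScale(hiddenSize, numHeads int64) float64 {
    headDim := hiddenSize / numHeads
    return 1.0 / math.Sqrt(float64(headDim))
}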
99
model/mllama/model.go
Normal file
@@ -0,0 +1,99 @@
package mllama

import (
    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/ml/nn"
    "github.com/ollama/ollama/model"
)

type Model struct {
    model.Base
    model.BytePairEncoding

    *VisionModel `gguf:"v,vision"`
    *TextModel

    Projector *nn.Linear `gguf:"mm.0"`

    ImageProcessor
}

func New(c ml.Config) (model.Model, error) {
    return &Model{
        BytePairEncoding: model.NewBytePairEncoding(
            c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
            &model.Vocabulary{
                Values: c.Strings("tokenizer.ggml.tokens"),
                Types:  c.Uints("tokenizer.ggml.token_type"),
                Merges: c.Strings("tokenizer.ggml.merges"),
                BOS:    c.Uint("tokenizer.ggml.bos_token_id"),
                EOS:    c.Uint("tokenizer.ggml.eos_token_id"),
            },
        ),
        ImageProcessor: newImageProcessor(c),
        VisionModel:    newVisionModel(c),
        TextModel:      newTextModel(c),
    }, nil
}

func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
    var crossAttentionStates ml.Tensor
    if opts.Images != nil {
        f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(opts.Images[0])
        if err != nil {
            return nil, err
        }

        pixelValues, err := ctx.FromFloatSlice(f32s,
            m.ImageProcessor.imageSize,
            m.ImageProcessor.imageSize,
            m.ImageProcessor.numChannels,
            m.ImageProcessor.maxNumTiles,
        )
        if err != nil {
            return nil, err
        }

        aspectRatio, err := ctx.FromIntSlice([]int32{int32(aspectRatioID)}, 1)
        if err != nil {
            return nil, err
        }

        positions := make([]int32, 1601)
        for i := range positions {
            positions[i] = int32(i)
        }

        positionIDs, err := ctx.FromIntSlice(positions, len(positions))
        if err != nil {
            return nil, err
        }

        crossAttentionStates = m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
        crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates)
    }

    inputs, err := ctx.FromIntSlice(opts.Inputs(), len(opts.Inputs()))
    if err != nil {
        return nil, err
    }

    positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions()))
    if err != nil {
        return nil, err
    }

    // TODO: attention mask, cross attention mask
    hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, opts.Cache)

    outputs, err := ctx.FromIntSlice([]int32{int32(len(opts.Positions())) - 1}, 1)
    if err != nil {
        return nil, err
    }

    return hiddenState.Rows(ctx, outputs), nil
}

func init() {
    model.Register("mllama", New)
}
225
model/mllama/model_text.go
Normal file
@@ -0,0 +1,225 @@
package mllama

import (
    "math"
    "slices"

    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/ml/nn"
    "github.com/ollama/ollama/model"
)

type TextSelfAttention struct {
    Query  *nn.Linear `gguf:"attn_q"`
    Key    *nn.Linear `gguf:"attn_k"`
    Value  *nn.Linear `gguf:"attn_v"`
    Output *nn.Linear `gguf:"attn_output"`
}

func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
    batchSize := hiddenState.Dim(1)
    headDim := opts.hiddenSize / opts.numHeads

    query := sa.Query.Forward(ctx, hiddenState)
    query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
    query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)

    key := sa.Key.Forward(ctx, hiddenState)
    key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
    key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)

    value := sa.Value.Forward(ctx, hiddenState)
    value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

    key, value = cache.Put(ctx, key, value, cache.Options)

    query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
    key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
    value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

    scores := key.Mulmat(ctx, query)
    scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))

    if mask != nil {
        scores = scores.Add(ctx, mask)
    }

    scores = scores.Softmax(ctx)

    attention := value.Mulmat(ctx, scores)
    attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
    attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)

    return sa.Output.Forward(ctx, attention)
}

type TextMLP struct {
    Up   *nn.Linear `gguf:"ffn_up"`
    Down *nn.Linear `gguf:"ffn_down"`
    Gate *nn.Linear `gguf:"ffn_gate"`
}

func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
    hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
    return mlp.Down.Forward(ctx, hiddenState)
}

type TextSelfAttentionDecoderLayer struct {
    AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
    SelfAttention *TextSelfAttention

    MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
    MLP     *TextMLP
}

func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
    residual := hiddenState

    hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
    hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)
    hiddenState = hiddenState.Add(ctx, residual)
    residual = hiddenState

    hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
    hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
    return hiddenState.Add(ctx, residual)
}

type TextCrossAttention struct {
    QueryNorm *nn.RMSNorm `gguf:"cross_attn_q_norm"`
    Query     *nn.Linear  `gguf:"cross_attn_q_proj"`
    KeyNorm   *nn.RMSNorm `gguf:"cross_attn_k_norm"`
    Key       *nn.Linear  `gguf:"cross_attn_k_proj"`
    Value     *nn.Linear  `gguf:"cross_attn_v_proj"`
    Output    *nn.Linear  `gguf:"cross_attn_o_proj"`
}

func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
    batchSize := hiddenState.Dim(1)
    headDim := opts.hiddenSize / opts.numHeads
    numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)

    query := ca.Query.Forward(ctx, hiddenState)
    query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
    query = ca.QueryNorm.Forward(ctx, query, opts.eps)

    key := ca.Key.Forward(ctx, crossAttentionStates)
    key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
    key = ca.KeyNorm.Forward(ctx, key, opts.eps)

    value := ca.Value.Forward(ctx, crossAttentionStates)
    value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)

    // TODO cache key, value

    query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
    key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
    value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

    scores := key.Mulmat(ctx, query)
    scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
    scores = scores.Softmax(ctx)

    attention := value.Mulmat(ctx, scores)
    attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
    attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)

    return ca.Output.Forward(ctx, attention)
}

type TextCrossAttentionDecoderLayer struct {
    AttentionNorm  *nn.RMSNorm `gguf:"attn_norm"`
    CrossAttention *TextCrossAttention
    AttentionGate  ml.Tensor `gguf:"cross_attn_attn_gate"`

    MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
    MLP     *TextMLP
    MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
}

func (d TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
    residual := hiddenState

    hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
    hiddenState = d.CrossAttention.Forward(ctx, hiddenState, crossAttentionStates, cache, opts)
    hiddenState = hiddenState.Mul(ctx, d.AttentionGate.Tanh(ctx))
    hiddenState = hiddenState.Add(ctx, residual)
    residual = hiddenState

    hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
    hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
    hiddenState = hiddenState.Mul(ctx, d.MLPGate.Tanh(ctx))
    return hiddenState.Add(ctx, residual)
}

type TextDecoderLayer interface {
    Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor
}

type TextDecoder struct {
    Layers []TextDecoderLayer
}

func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
    for i, layer := range d.Layers {
        if !slices.Contains(opts.crossAttentionLayers, uint32(i)) || crossAttentionStates != nil {
            hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache.Sub(i), opts)
        }
    }

    return hiddenState
}

type TextModelOptions struct {
    RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`

    hiddenSize, numHeads, numKVHeads int64
    eps, ropeBase, ropeScale         float32
    ropeDim                          uint32

    crossAttentionLayers []uint32
}

type TextModel struct {
    TokenEmbedding *nn.Embedding `gguf:"token_embd"`
    Transformer    *TextDecoder  `gguf:"blk"`
    OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
    Output         *nn.Linear    `gguf:"output"`

    *TextModelOptions
}

func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache) ml.Tensor {
    hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
    hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
    hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
    return m.Output.Forward(ctx, hiddenState)
}

func newTextModel(c ml.Config) *TextModel {
    var decoderLayers []TextDecoderLayer
    for i := range c.Uint("block_count") {
        var textDecoderLayer TextDecoderLayer
        if slices.Contains(c.Uints("attention.cross_attention_layers"), i) {
            textDecoderLayer = &TextCrossAttentionDecoderLayer{}
        } else {
            textDecoderLayer = &TextSelfAttentionDecoderLayer{}
        }

        decoderLayers = append(decoderLayers, textDecoderLayer)
    }

    return &TextModel{
        Transformer: &TextDecoder{Layers: decoderLayers},
        TextModelOptions: &TextModelOptions{
            hiddenSize:           int64(c.Uint("embedding_length")),
            numHeads:             int64(c.Uint("attention.head_count")),
            numKVHeads:           int64(c.Uint("attention.head_count_kv")),
            eps:                  c.Float("attention.layer_norm_rms_epsilon"),
            ropeBase:             c.Float("rope.freq_base"),
            ropeScale:            c.Float("rope.freq_scale", 1),
            ropeDim:              c.Uint("rope.dimension_count"),
            crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
        },
    }
}
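The selection loop in newTextModel interleaves the two decoder layer kinds by index. A minimal sketch mirroring that logic, with made-up counts (8 blocks, cross-attention at 3 and 7) purely for illustration:

package sketch

import "slices"

// layerKinds returns one entry per block: true where a cross-attention
// decoder layer would be used, false for a self-attention layer.
func layerKinds(blockCount uint32, crossAttentionLayers []uint32) []bool {
    kinds := make([]bool, blockCount)
    for i := range kinds {
        kinds[i] = slices.Contains(crossAttentionLayers, uint32(i))
    }
    return kinds
}

// layerKinds(8, []uint32{3, 7}) -> [false false false true false false false true]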
234
model/mllama/model_vision.go
Normal file
234
model/mllama/model_vision.go
Normal file
@ -0,0 +1,234 @@
|
||||
package mllama
|
||||
|
||||
import (
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
var batchSize int64 = 1
|
||||
|
||||
type VisionSelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_out"`
|
||||
|
||||
Gate ml.Tensor `gguf:"attn_gate"`
|
||||
}
|
||||
|
||||
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenState)
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
|
||||
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
|
||||
key := sa.Key.Forward(ctx, hiddenState)
|
||||
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
|
||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
scores := key.Mulmat(ctx, query)
|
||||
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
|
||||
scores = scores.Softmax(ctx)
|
||||
|
||||
attention := value.Mulmat(ctx, scores)
|
||||
attention = attention.Reshape(ctx, headDim, attention.Dim(1), opts.numHeads, batchSize)
|
||||
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
||||
|
||||
hiddenState = sa.Output.Forward(ctx, attention)
|
||||
if sa.Gate != nil {
|
||||
hiddenState = hiddenState.Mul(ctx, sa.Gate)
|
||||
}
|
||||
|
||||
return hiddenState
|
||||
}
|
||||
|
||||
type VisionMLP struct {
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
|
||||
Gate ml.Tensor `gguf:"ffn_gate"`
|
||||
}
|
||||
|
||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
hiddenState = mlp.Down.Forward(ctx, hiddenState).GELU(ctx)
|
||||
hiddenState = mlp.Up.Forward(ctx, hiddenState)
|
||||
if mlp.Gate != nil {
|
||||
hiddenState = hiddenState.Mul(ctx, mlp.Gate)
|
||||
}
|
||||
|
||||
return hiddenState
|
||||
}
|
||||
|
||||
type VisionEncoderLayer struct {
|
||||
AttentionNorm *nn.LayerNorm `gguf:"ln1"`
|
||||
SelfAttention *VisionSelfAttention
|
||||
|
||||
MLPNorm *nn.LayerNorm `gguf:"ln2"`
|
||||
MLP *VisionMLP
|
||||
}
|
||||
|
||||
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
// self attention
|
||||
hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
residual = hiddenState
|
||||
|
||||
// feed forward
|
||||
hiddenState = e.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
|
||||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
type VisionEncoder struct {
|
||||
Layers []VisionEncoderLayer
|
||||
}
|
||||
|
||||
func (e *VisionEncoder) Forward(ctx ml.Context, hiddenState ml.Tensor, intermediateLayersIndices []uint32, opts *VisionModelOptions) (ml.Tensor, []ml.Tensor) {
|
||||
var intermediateHiddenStates []ml.Tensor
|
||||
for i, layer := range e.Layers {
|
||||
if slices.Contains(intermediateLayersIndices, uint32(i)) {
|
||||
intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append([]int64{1}, hiddenState.Shape()...)...))
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, hiddenState, opts)
|
||||
}
|
||||
|
||||
return hiddenState, intermediateHiddenStates
|
||||
}
|
||||
|
||||
type PrecomputedAspectRatioEmbedding struct {
|
||||
Embedding *nn.Embedding
|
||||
Gate ml.Tensor `gguf:"gate"`
|
||||
}
|
||||
|
||||
func (e *PrecomputedAspectRatioEmbedding) Forward(ctx ml.Context, hiddenState ml.Tensor, aspectRatioIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
embeddings := e.Embedding.Forward(ctx, aspectRatioIDs)
|
||||
embeddings = embeddings.Reshape(ctx, opts.hiddenSize, 1, opts.numTiles)
|
||||
if e.Gate != nil {
|
||||
embeddings = embeddings.Mul(ctx, e.Gate)
|
||||
}
|
||||
|
||||
return hiddenState.Add(ctx, embeddings)
|
||||
}
|
||||
|
||||
type PrecomputedPositionEmbedding struct {
|
||||
PositionEmbedding *nn.Embedding `gguf:"position_embd"`
|
||||
PositionEmbeddingGate ml.Tensor `gguf:"position_embd.gate"`
|
||||
|
||||
TilePositionEmbedding *nn.Embedding `gguf:"tile_position_embd"`
|
||||
TilePositionEmbeddingGate ml.Tensor `gguf:"tile_position_embd.gate"`
|
||||
}
|
||||
|
||||
func (e *PrecomputedPositionEmbedding) Forward(ctx ml.Context, hiddenState, positionIDs, aspectRatioIDs ml.Tensor, numPositions int64, opts *VisionModelOptions) ml.Tensor {
|
||||
positionEmbedding := e.PositionEmbedding.Forward(ctx, positionIDs)
|
||||
if e.PositionEmbeddingGate != nil {
|
||||
positionEmbedding = positionEmbedding.Mul(ctx, e.PositionEmbeddingGate)
|
||||
}
|
||||
|
||||
hiddenState = hiddenState.Add(ctx, positionEmbedding)
|
||||
|
||||
	tilePositionEmbedding := e.TilePositionEmbedding.Forward(ctx, aspectRatioIDs)
	tilePositionEmbedding = tilePositionEmbedding.Reshape(ctx, opts.hiddenSize, numPositions, opts.numTiles)
	if e.TilePositionEmbeddingGate != nil {
		tilePositionEmbedding = tilePositionEmbedding.Mul(ctx, e.TilePositionEmbeddingGate)
	}

	return hiddenState.Add(ctx, tilePositionEmbedding)
}

type VisionModelOptions struct {
	hiddenSize, numHeads, numTiles int64
	imageSize, patchSize           int
	eps                            float32

	intermediateLayersIndices []uint32
}

type VisionModel struct {
	PatchEmbeddings *nn.Conv2D `gguf:"patch_embd"`

	PreTilePositionEmbedding  *PrecomputedAspectRatioEmbedding `gguf:"pre_tile_position_embd"`
	PostTilePositionEmbedding *PrecomputedAspectRatioEmbedding `gguf:"post_tile_position_embd"`
	PositionEmbedding         *PrecomputedPositionEmbedding

	PreLayerNorm   *nn.LayerNorm `gguf:"pre_ln"`
	PostLayerNorm  *nn.LayerNorm `gguf:"post_ln"`
	ClassEmbedding ml.Tensor     `gguf:"class_embd"`

	Transformer       *VisionEncoder `gguf:"blk"`
	GlobalTransformer *VisionEncoder `gguf:"global.blk"`

	*VisionModelOptions
}

func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRatioIDs ml.Tensor) ml.Tensor {
	numPatches := int64((m.imageSize / m.patchSize) * (m.imageSize / m.patchSize))
	numPositions := numPatches
	if m.ClassEmbedding != nil {
		numPositions++
	}

	hiddenState := m.PatchEmbeddings.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
	hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize, m.numTiles)
	hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)

	hiddenState = m.PreTilePositionEmbedding.Forward(ctx, hiddenState, aspectRatioIDs, m.VisionModelOptions)
	hiddenState = m.ClassEmbedding.Stack(ctx, 2, slices.Repeat([]ml.Tensor{m.ClassEmbedding}, int(m.numTiles)-1)...).Concat(ctx, hiddenState, 1)

	hiddenState = m.PositionEmbedding.Forward(ctx, hiddenState, positionIDs, aspectRatioIDs, numPositions, m.VisionModelOptions)
	hiddenState = m.PreLayerNorm.Forward(ctx, hiddenState, m.eps)

	numPaddingPatches := 8 - (hiddenState.Dim(1)%8)%8
	hiddenState = hiddenState.Pad(ctx, 0, numPaddingPatches, 0, 0)

	hiddenState = hiddenState.Reshape(ctx, hiddenState.Dim(0), hiddenState.Dim(1)*hiddenState.Dim(2), batchSize)
	hiddenState, intermediateHiddenStates := m.Transformer.Forward(ctx, hiddenState, m.intermediateLayersIndices, m.VisionModelOptions)

	hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)

	hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize)
	hiddenState = m.PostTilePositionEmbedding.Forward(ctx, hiddenState, aspectRatioIDs, m.VisionModelOptions)

	hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, m.numTiles*(numPositions+numPaddingPatches), batchSize)
	hiddenState, _ = m.GlobalTransformer.Forward(ctx, hiddenState, nil, m.VisionModelOptions)

	hiddenStates := intermediateHiddenStates[0].Stack(ctx, 0, intermediateHiddenStates[1:]...)
	hiddenStates = hiddenStates.Reshape(ctx, int64(len(intermediateHiddenStates))*m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize)
	hiddenStates = hiddenStates.Unpad(ctx, 0, numPaddingPatches, 0, 0)

	hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize)
	hiddenState = hiddenState.Unpad(ctx, 0, numPaddingPatches, 0, 0)
	return hiddenState.Concat(ctx, hiddenStates, 0)
}

func newVisionModel(c ml.Config) *VisionModel {
	return &VisionModel{
		Transformer:       &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count"))},
		GlobalTransformer: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.global.block_count"))},

		VisionModelOptions: &VisionModelOptions{
			hiddenSize: int64(c.Uint("vision.embedding_length")),
			numHeads:   int64(c.Uint("vision.attention.head_count")),
			numTiles:   int64(c.Uint("vision.max_num_tiles")),

			imageSize: int(c.Uint("vision.image_size")),
			patchSize: int(c.Uint("vision.patch_size")),

			eps: c.Float("vision.attention.layer_norm_epsilon"),

			intermediateLayersIndices: c.Uints("vision.intermediate_layers_indices"),
		},
	}
}
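For orientation, a minimal sketch (not part of this diff) of how the vision tower above could be driven once pixel values and the position/aspect-ratio IDs have been prepared. The FromFloatSlice/FromIntSlice constructors on ml.Context are assumptions here; only newVisionModel and Forward come from this change.

// Hypothetical driver for the mllama vision tower; tensor shapes and the
// context constructors are assumed, not guaranteed by this diff.
func encodeTiles(ctx ml.Context, c ml.Config, pixels []float32, positions, ratios []int32) (ml.Tensor, error) {
	m := newVisionModel(c)

	size := int(c.Uint("vision.image_size"))
	pixelValues, err := ctx.FromFloatSlice(pixels, size, size, int(c.Uint("vision.num_channels")), int(c.Uint("vision.max_num_tiles")))
	if err != nil {
		return nil, err
	}

	positionIDs, err := ctx.FromIntSlice(positions, len(positions))
	if err != nil {
		return nil, err
	}

	aspectRatioIDs, err := ctx.FromIntSlice(ratios, len(ratios))
	if err != nil {
		return nil, err
	}

	return m.Forward(ctx, pixelValues, positionIDs, aspectRatioIDs), nil
}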
240
model/mllama/process_image.go
Normal file
@ -0,0 +1,240 @@
package mllama

import (
	"image"
	"image/color"
	"math"
	"slices"

	"golang.org/x/image/draw"

	"github.com/ollama/ollama/ml"
)

type ImageProcessor struct {
	imageSize, numChannels, maxNumTiles int
}

func newImageProcessor(c ml.Config) ImageProcessor {
	return ImageProcessor{
		imageSize:   int(c.Uint("vision.image_size")),
		numChannels: int(c.Uint("vision.num_channels")),
		maxNumTiles: int(c.Uint("vision.max_num_tiles")),
	}
}

func (p *ImageProcessor) supportedAspectRatios(maxTiles int) []image.Point {
	ratios := []image.Point{}

	for w := range maxTiles {
		for h := range maxTiles {
			if (w+1)*(h+1) <= maxTiles {
				ratios = append(ratios, image.Point{w + 1, h + 1})
			}
		}
	}

	return ratios
}
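A worked example (not in the diff): for maxTiles = 4 the loops keep every w-by-h grid whose tile count stays within the budget, in w-major order.

// supportedAspectRatios(4) ==
//   [(1,1) (1,2) (1,3) (1,4) (2,1) (2,2) (3,1) (4,1)]
// i.e. eight tile arrangements, each using at most four tiles.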
func (p *ImageProcessor) clip(a, a_min, a_max int) int {
	if a < a_min {
		return a_min
	} else if a > a_max {
		return a_max
	}

	return a
}

func (p *ImageProcessor) fitToCanvas(imageSize, canvasSize image.Point, tileSize int) image.Point {
	targetWidth := p.clip(imageSize.X, tileSize, canvasSize.X)
	targetHeight := p.clip(imageSize.Y, tileSize, canvasSize.Y)

	scaleWidth := float64(targetWidth) / float64(imageSize.X)
	scaleHeight := float64(targetHeight) / float64(imageSize.Y)

	var w, h int

	if scaleWidth < scaleHeight {
		w = targetWidth
		h = min(int(math.Floor(float64(imageSize.Y)*scaleWidth)), targetHeight)
	} else {
		w = min(int(math.Floor(float64(imageSize.X)*scaleHeight)), targetWidth)
		h = targetHeight
	}

	return image.Point{w, h}
}

func (p *ImageProcessor) optimalTiledCanvas(imageSize image.Point, maxImageTiles, tileSize int) image.Point {
	possibleTileArrangements := p.supportedAspectRatios(maxImageTiles)
	possibleCanvasSizes := []image.Point{}
	for _, pta := range possibleTileArrangements {
		possibleCanvasSizes = append(possibleCanvasSizes, image.Point{pta.X * tileSize, pta.Y * tileSize})
	}

	scales := []float64{}

	for _, pcs := range possibleCanvasSizes {
		scaleHeight := float64(pcs.Y) / float64(imageSize.Y)
		scaleWidth := float64(pcs.X) / float64(imageSize.X)

		if scaleWidth > scaleHeight {
			scales = append(scales, scaleHeight)
		} else {
			scales = append(scales, scaleWidth)
		}
	}

	var minUpscale float64
	var maxDownscale float64
	var upscale bool

	for _, s := range scales {
		if s > 1.0 {
			upscale = true
			if minUpscale == 0 {
				minUpscale = s
			} else {
				minUpscale = math.Min(minUpscale, s)
			}
		} else {
			maxDownscale = math.Max(maxDownscale, s)
		}
	}

	selectedScale := maxDownscale
	if upscale {
		selectedScale = minUpscale
	}

	var selectedCanvas image.Point
	for n, pcs := range possibleCanvasSizes {
		if scales[n] == selectedScale {
			// choose the smallest possible canvas
			if selectedCanvas.X == 0 && selectedCanvas.Y == 0 {
				selectedCanvas = pcs
			} else if pcs.X*pcs.Y < selectedCanvas.X*selectedCanvas.Y {
				selectedCanvas = pcs
			}
		}
	}
	return selectedCanvas
}
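In words: each candidate canvas is scored by the scale that fits the image inside it without distortion; if any canvas can hold the image without upscaling, the mildest downscale wins, otherwise the smallest required upscale wins, with canvas area as the tiebreaker. A worked example under assumed numbers:

// Worked example (not in the diff): a 448x224 image, tileSize 224, maxImageTiles 4.
// Candidate canvases include 224x224 (scale 0.5), 448x224 (1.0), 448x448 (1.0),
// 672x224 (1.0), 896x224 (1.0), ... No scale exceeds 1.0, so no upscaling is
// needed; the best downscale is 1.0, and the smallest canvas achieving it is
// 448x224, i.e. a 2x1 tile arrangement.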
func (p *ImageProcessor) splitToTiles(img image.Image, numTilesSize image.Point) []image.Image {
	b := img.Bounds()
	width := b.Max.X - b.Min.X
	height := b.Max.Y - b.Min.Y
	tileHeight := height / numTilesSize.Y
	tileWidth := width / numTilesSize.X

	images := []image.Image{}

	for h := range numTilesSize.Y {
		for w := range numTilesSize.X {
			rect := image.Rect(tileWidth*w, tileHeight*h, tileWidth*(w+1), tileHeight*(h+1))
			images = append(images, img.(interface {
				SubImage(image.Rectangle) image.Image
			}).SubImage(rect))
		}
	}

	return images
}

// remove the "alpha" channel by drawing over a prefilled image
//
//nolint:unused
func (p *ImageProcessor) compositeImage(img image.Image) image.Image {
	dst := image.NewRGBA(img.Bounds())

	white := color.RGBA{255, 255, 255, 255}
	draw.Draw(dst, dst.Bounds(), &image.Uniform{white}, image.Point{}, draw.Src)
	draw.Draw(dst, dst.Bounds(), img, img.Bounds().Min, draw.Over)

	return dst
}

func (p *ImageProcessor) resize(img image.Image, outputSize image.Point, maxImageTiles int) (image.Image, image.Point) {
	b := img.Bounds()
	tileSize := outputSize.Y

	canvasSize := p.optimalTiledCanvas(b.Max, maxImageTiles, tileSize)
	aspectRatio := image.Point{canvasSize.X / tileSize, canvasSize.Y / tileSize}
	newSize := p.fitToCanvas(b.Max, canvasSize, tileSize)

	dst := image.NewRGBA(image.Rect(0, 0, newSize.X, newSize.Y))

	// scaling choices:
	//   NearestNeighbor  fast, blocky output
	//   ApproxBiLinear   fast, medium quality
	//   BiLinear         slow, high quality
	//   CatmullRom       very slow, very high quality
	draw.BiLinear.Scale(dst, dst.Rect, img, b, draw.Over, nil)

	return dst, aspectRatio
}

func (p *ImageProcessor) pad(img image.Image, outputSize, aspectRatio image.Point) image.Image {
	paddedSize := image.Point{
		X: outputSize.X * aspectRatio.X,
		Y: outputSize.Y * aspectRatio.Y,
	}

	dst := image.NewRGBA(image.Rect(0, 0, paddedSize.X, paddedSize.Y))
	draw.Draw(dst, img.Bounds(), img, image.Point{0, 0}, draw.Over)

	return dst
}

func (p *ImageProcessor) pack(img image.Image, aspectRatio image.Point, mean, std [3]float32) []float32 {
	subImages := p.splitToTiles(img, aspectRatio)

	var pixelVals []float32

	for _, subImg := range subImages {
		bounds := subImg.Bounds()
		var rVals, gVals, bVals []float32
		for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
			for x := bounds.Min.X; x < bounds.Max.X; x++ {
				c := subImg.At(x, y)
				r, g, b, _ := c.RGBA()
				rVal := float32(r>>8) / 255.0
				gVal := float32(g>>8) / 255.0
				bVal := float32(b>>8) / 255.0

				rVal = (rVal - mean[0]) / std[0]
				gVal = (gVal - mean[1]) / std[1]
				bVal = (bVal - mean[2]) / std[2]

				rVals = append(rVals, rVal)
				gVals = append(gVals, gVal)
				bVals = append(bVals, bVal)
			}
		}
		pixelVals = append(pixelVals, rVals...)
		pixelVals = append(pixelVals, gVals...)
		pixelVals = append(pixelVals, bVals...)
	}

	return pixelVals
}

func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, int, error) {
	outputSize := image.Point{p.imageSize, p.imageSize}

	// CLIP normalization constants (per-channel mean and std)
	mean := [3]float32{0.48145466, 0.4578275, 0.40821073}
	std := [3]float32{0.26862954, 0.26130258, 0.27577711}

	newImage, aspectRatio := p.resize(img, outputSize, p.maxNumTiles)
	newImage = p.pad(newImage, outputSize, aspectRatio)

	data := p.pack(newImage, aspectRatio, mean, std)
	aspectRatioIndex := slices.Index(p.supportedAspectRatios(p.maxNumTiles), aspectRatio) + 1
	return data, aspectRatioIndex, nil
}
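A hedged usage sketch (not in the diff): the image decoding and the "io" import are assumptions; ProcessImage itself is from this change. Note the returned index is 1-based, since ProcessImage adds 1 to the position in the supported-ratio list.

// Hypothetical caller: decode an image and turn it into normalized tile
// data plus an aspect-ratio index for the vision tower. Assumes the
// relevant image decoders are registered and "io" is imported.
func prepareImage(c ml.Config, r io.Reader) ([]float32, int, error) {
	img, _, err := image.Decode(r)
	if err != nil {
		return nil, 0, err
	}

	p := newImageProcessor(c)
	return p.ProcessImage(img)
}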
279
model/model.go
Normal file
@ -0,0 +1,279 @@
package model

import (
	"fmt"
	"image"
	_ "image/jpeg"
	_ "image/png"
	"log/slog"
	"os"
	"reflect"
	"strconv"
	"strings"

	_ "golang.org/x/image/bmp"
	_ "golang.org/x/image/tiff"
	_ "golang.org/x/image/webp"

	"github.com/ollama/ollama/cache"
	"github.com/ollama/ollama/ml"
	_ "github.com/ollama/ollama/ml/backend"
)

type Cache struct {
	cache.Cache
	cache.Options
}

func (c Cache) Sub(i int) Cache {
	if c.Cache != nil {
		return Cache{
			Cache:   c.Cache.Sub(i),
			Options: c.Options,
		}
	}

	return c
}

func (c Cache) Put(ctx ml.Context, key, value ml.Tensor, opts cache.Options) (ml.Tensor, ml.Tensor) {
	if c.Cache != nil {
		return c.Cache.Put(ctx, key, value, opts)
	}

	return key, value
}

type Options struct {
	inputs []int32

	Offset int

	Images []image.Image

	Cache
}

func (opts Options) Inputs() []int32 {
	return opts.inputs[opts.Offset:]
}

func (opts Options) Positions() []int32 {
	positions := make([]int32, len(opts.inputs)-opts.Offset)
	for i := range positions {
		positions[i] = int32(opts.Offset + i)
	}

	return positions
}

type OptionsFunc func(Model, *Options)

func WithInputIDs(ids []int32) OptionsFunc {
	return func(m Model, opts *Options) {
		opts.inputs = ids
	}
}

func WithOffset(offset int) OptionsFunc {
	return func(m Model, opts *Options) {
		opts.Offset = offset
		opts.Cache.Position = offset
	}
}

func WithImage(img image.Image) OptionsFunc {
	return func(m Model, opts *Options) {
		opts.Images = append(opts.Images, img)
	}
}

func WithCache(c cache.Cache) OptionsFunc {
	return func(m Model, opts *Options) {
		opts.Cache = Cache{
			Cache: c,
			Options: cache.Options{
				Position: opts.Offset,
			},
		}
	}
}

type Base struct {
	b ml.Backend
}

func (m *Base) Backend() ml.Backend {
	return m.b
}

type Model interface {
	Forward(ml.Context, Options) (ml.Tensor, error)

	Backend() ml.Backend
}

var models = make(map[string]func(ml.Config) (Model, error))

func Register(name string, f func(ml.Config) (Model, error)) {
	if _, ok := models[name]; ok {
		panic("model: model already registered")
	}

	models[name] = f
}

func New(s string) (Model, error) {
	r, err := os.Open(s)
	if err != nil {
		return nil, err
	}
	defer r.Close()

	b, err := ml.NewBackend(r)
	if err != nil {
		return nil, err
	}

	arch := b.Config().Architecture()
	f, ok := models[arch]
	if !ok {
		return nil, fmt.Errorf("unsupported model architecture %q", arch)
	}

	m, err := f(b.Config())
	if err != nil {
		return nil, err
	}

	v := reflect.ValueOf(m)
	v.Elem().Set(populateFields(b, v))
	return m, nil
}

func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
	t := v.Type()
	if t.Kind() == reflect.Pointer {
		t, v = t.Elem(), v.Elem()
	}

	if t.Kind() == reflect.Struct {
		allNil := true
		for i := range t.NumField() {
			tt := t.Field(i).Type
			vv := v.Field(i)
			if !vv.CanSet() {
				continue
			}

			// make a copy
			tagsCopy := tags
			if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
				tagsCopy = append(tagsCopy, ParseTags(tag))
			}

			if tt == reflect.TypeOf((*Base)(nil)).Elem() {
				vv.Set(reflect.ValueOf(Base{b: b}))
			} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
				var fn func([]Tag) [][]string
				fn = func(tags []Tag) (values [][]string) {
					if len(tags) < 1 {
						return nil
					}

					values = [][]string{{tags[0].Name}}
					for _, alt := range tags[0].Alternate {
						values = append(values, []string{alt})
					}

					for i, value := range values {
						for _, rest := range fn(tags[1:]) {
							value = append(value, rest...)
						}

						values[i] = value
					}

					return values
				}

				names := fn(tagsCopy)
				for _, name := range names {
					if tensor := b.Get(strings.Join(name, ".")); tensor != nil {
						slog.Debug("found tensor", "", tensor)
						vv.Set(reflect.ValueOf(tensor))
						break
					}
				}
			} else if tt.Kind() == reflect.Pointer {
				vvv := vv.Elem()
				if vv.IsNil() {
					vvv = reflect.New(tt.Elem())
				}

				if f := populateFields(b, vvv, tagsCopy...); f.CanAddr() {
					vv.Set(f.Addr())
				}
			} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
				for i := range vv.Len() {
					vv.Index(i).Set(populateFields(b, vv.Index(i), append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
				}
			}

			if !canNil(tt) || !vv.IsNil() {
				allNil = false
			}
		}

		if allNil {
			return reflect.Zero(t)
		}
	}

	return v
}

type Tag struct {
	Name      string
	Alternate []string
}

func ParseTags(s string) (tag Tag) {
	parts := strings.Split(s, ",")
	if len(parts) > 0 {
		tag.Name = parts[0]

		for _, part := range parts[1:] {
			if value, ok := strings.CutPrefix(part, "alt:"); ok {
				tag.Alternate = append(tag.Alternate, value)
			}
		}
	}

	return
}

func canNil(t reflect.Type) bool {
	return t.Kind() == reflect.Chan ||
		t.Kind() == reflect.Func ||
		t.Kind() == reflect.Interface ||
		t.Kind() == reflect.Map ||
		t.Kind() == reflect.Pointer ||
		t.Kind() == reflect.Slice
}

func Forward(m Model, optsFuncs ...OptionsFunc) (ml.Tensor, error) {
	var opts Options
	for _, optsFunc := range optsFuncs {
		optsFunc(m, &opts)
	}

	ctx := m.Backend().NewContext()
	t, err := m.Forward(ctx, opts)
	if err != nil {
		return nil, err
	}
	defer ctx.Close()

	return ctx.Compute(t), nil
}
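A minimal sketch (not in the diff) of how these pieces compose: load a GGUF file, then run one forward pass with the functional options above. The model path and token IDs are placeholders.

// Hypothetical end-to-end call; New, Forward, and the option helpers are
// from this diff, the inputs are made up. WithOffset runs before WithCache
// so the cache position picks up the offset.
func generateOnce(path string, prompt []int32, kv cache.Cache) (ml.Tensor, error) {
	m, err := New(path)
	if err != nil {
		return nil, err
	}

	return Forward(m,
		WithInputIDs(prompt),
		WithOffset(0),
		WithCache(kv),
	)
}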
136
model/model_test.go
Normal file
@ -0,0 +1,136 @@
package model

import (
	"reflect"
	"slices"
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/backend/ggml"
	"github.com/ollama/ollama/ml/nn"
)

func TestParseTags(t *testing.T) {
	cases := []struct {
		value string
		want  Tag
	}{
		{
			value: "output",
			want: Tag{
				Name: "output",
			},
		},
		{
			value: "output,alt:token_embd",
			want: Tag{
				Name: "output",
				Alternate: []string{
					"token_embd",
				},
			},
		},
	}

	for _, tt := range cases {
		t.Run(tt.value, func(t *testing.T) {
			got := ParseTags(tt.value)
			if diff := cmp.Diff(tt.want, got); diff != "" {
				t.Errorf("ParseTags() returned unexpected values (-want +got):\n%s", diff)
			}
		})
	}
}

type fakeBackend struct {
	*ggml.Backend
	names []string
}

type fakeTensor struct {
	*ggml.Tensor
	Name string
}

func (m *fakeBackend) Get(name string) ml.Tensor {
	if slices.Contains(m.names, name) {
		return &fakeTensor{Name: name}
	}

	return nil
}

func TestPopulateFields(t *testing.T) {
	type fakeLayer struct {
		Query  *nn.Linear `gguf:"attn_q"`
		Key    *nn.Linear `gguf:"attn_k"`
		Value  *nn.Linear `gguf:"attn_v"`
		Output *nn.Linear `gguf:"attn_o"`
	}

	type fakeModel struct {
		Input      *nn.Embedding `gguf:"input"`
		OutputNorm *nn.RMSNorm   `gguf:"output_norm"`
		Output     *nn.Linear    `gguf:"output"`
		Layers     [2]fakeLayer  `gguf:"blk"`
	}

	var m fakeModel
	v := reflect.ValueOf(&m)
	v.Elem().Set(populateFields(&fakeBackend{
		names: []string{
			"input.weight",
			"blk.0.attn_q.weight",
			"blk.0.attn_k.weight",
			"blk.0.attn_v.weight",
			"blk.1.attn_q.weight",
			"blk.1.attn_k.weight",
			"blk.1.attn_v.weight",
			"output_norm.weight",
			"output.weight",
		},
	}, v))

	if diff := cmp.Diff(fakeModel{
		Input:      &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
		OutputNorm: &nn.RMSNorm{Weight: &fakeTensor{Name: "output_norm.weight"}},
		Output:     &nn.Linear{Weight: &fakeTensor{Name: "output.weight"}},
		Layers: [2]fakeLayer{
			{
				Query: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.attn_q.weight"}},
				Key:   &nn.Linear{Weight: &fakeTensor{Name: "blk.0.attn_k.weight"}},
				Value: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.attn_v.weight"}},
			},
			{
				Query: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.attn_q.weight"}},
				Key:   &nn.Linear{Weight: &fakeTensor{Name: "blk.1.attn_k.weight"}},
				Value: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.attn_v.weight"}},
			},
		},
	}, m); diff != "" {
		t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
	}
}

func TestPopulateFieldsAlternateName(t *testing.T) {
	type fakeModel struct {
		Input  *nn.Embedding `gguf:"input"`
		Output *nn.Linear    `gguf:"output,alt:input"`
	}

	m := fakeModel{}
	v := reflect.ValueOf(&m)
	v.Elem().Set(populateFields(&fakeBackend{
		names: []string{
			"input.weight",
		},
	}, v))

	if diff := cmp.Diff(fakeModel{
		Input:  &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
		Output: &nn.Linear{Weight: &fakeTensor{Name: "input.weight"}},
	}, m); diff != "" {
		t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
	}
}
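To make the lookup order concrete (a reading of populateFields and these tests, not new behavior): nested gguf tags are joined with dots, and alt: names fan out into additional candidates.

// For the fakeModel above, Layers[0].Query accumulates the tags
// "blk", "0", "attn_q", and the backend is asked for
// "blk.0.attn_q.weight" -- the trailing "weight" presumably comes from
// a gguf tag inside nn.Linear itself (an assumption from the expected
// names). A field tagged `gguf:"output,alt:input"` produces two
// candidates, "output.weight" then "input.weight", tried in that order.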
313
model/process_text.go
Normal file
@ -0,0 +1,313 @@
package model

import (
	"cmp"
	"iter"
	"log/slog"
	"strings"
	"sync"

	"github.com/dlclark/regexp2"
	heap "github.com/emirpasic/gods/v2/trees/binaryheap"
)

type Special int32

const (
	SpecialBOS Special = iota
	SpecialEOS
)

type TextProcessor interface {
	Encode(string) ([]int32, error)
	Decode([]int32) (string, error)
	Is(uint32, Special) bool
}

type Vocabulary struct {
	Values []string
	Types  []uint32
	Scores []uint32
	Merges []string

	BOS, EOS uint32

	specialOnce sync.Once
	special     []string

	valuesOnce sync.Once
	values     map[string]int32

	mergeOnce sync.Once
	merge     map[string]int32
}

func (v *Vocabulary) Is(id uint32, special Special) bool {
	switch special {
	case SpecialBOS:
		return id == v.BOS
	case SpecialEOS:
		return id == v.EOS
	default:
		return false
	}
}

func (v *Vocabulary) Encode(s string) int32 {
	v.valuesOnce.Do(func() {
		v.values = make(map[string]int32, len(v.Values))
		for i, value := range v.Values {
			v.values[value] = int32(i)
		}
	})

	if id, ok := v.values[s]; ok {
		return id
	}

	return -1
}

func (v *Vocabulary) Decode(id int32) string {
	return v.Values[id]
}

func (v *Vocabulary) SpecialVocabulary() []string {
	v.specialOnce.Do(func() {
		for i := range v.Values {
			if v.Types[i] == 3 {
				v.special = append(v.special, v.Values[i])
			}
		}
	})

	return v.special
}

func (v *Vocabulary) Merge(left, right string) int {
	v.mergeOnce.Do(func() {
		v.merge = make(map[string]int32, len(v.Merges))
		for i, merge := range v.Merges {
			v.merge[merge] = int32(i)
		}
	})

	if id, ok := v.merge[left+" "+right]; ok {
		return int(id)
	}

	return -1
}

type BytePairEncoding struct {
	pre   *regexp2.Regexp
	vocab *Vocabulary
}

func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding {
	return BytePairEncoding{
		pre:   regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
		vocab: vocab,
	}
}

func (bpe BytePairEncoding) Is(id uint32, special Special) bool {
	return bpe.vocab.Is(id, special)
}

func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
	return func(yield func(string) bool) {
		for m, _ := bpe.pre.FindStringMatch(s); m != nil; m, _ = bpe.pre.FindNextMatch(m) {
			if !yield(m.String()) {
				break
			}
		}
	}
}

// fragment is a string fragment and its corresponding token IDs
type fragment struct {
	value string
	ids   []int32
}

// pair is a pair of runes and its rank
type pair struct {
	a, b  int
	rank  int
	value string
}

type merge struct {
	p, n  int
	runes []rune
}

func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
	fragments := []fragment{{value: s}}
	for _, special := range bpe.vocab.SpecialVocabulary() {
		// TODO: process special tokens concurrently
		id := bpe.vocab.Encode(special)
		for i := 0; i < len(fragments); i++ {
			frag := fragments[i]
			if len(frag.ids) > 0 {
				continue
			}

			var middle []fragment
			switch i := strings.Index(frag.value, special); {
			case i < 0:
				middle = append(middle, frag)
			case i > 0:
				middle = append(middle, fragment{value: frag.value[:i]})
				fallthrough
			default:
				middle = append(middle, fragment{value: special, ids: []int32{id}})
				if rest := frag.value[i+len(special):]; rest != "" {
					middle = append(middle, fragment{value: rest})
				}
			}

			fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
		}
	}

	var ids []int32
	for _, frag := range fragments {
		if len(frag.ids) > 0 {
			ids = append(ids, frag.ids...)
			slog.Debug("encoded", "text", frag.value, "ids", frag.ids, "special", true)
			continue
		}

		for split := range bpe.split(frag.value) {
			// TODO: process splits concurrently
			var sb strings.Builder
			for _, b := range []byte(split) {
				// remap raw bytes onto printable stand-in runes before
				// vocabulary lookup (byte-level BPE convention)
				r := rune(b)
				switch {
				case r == 0x00ad:
					r = 0x0143
				case r <= 0x0020:
					r = r + 0x0100
				case r >= 0x007e && r <= 0x00a0:
					r = r + 0x00a2
				}

				sb.WriteRune(r)
			}

			// short circuit if the fragment is in the vocabulary
			if id := bpe.vocab.Encode(sb.String()); id >= 0 {
				ids = append(ids, id)
				slog.Debug("encoded", "text", sb.String(), "ids", []int32{id})
				continue
			}

			runes := []rune(sb.String())
			merges := make([]merge, len(runes))
			for r := range runes {
				merges[r] = merge{
					p:     r - 1,
					n:     r + 1,
					runes: []rune{runes[r]},
				}
			}

			pairwise := func(a, b int) *pair {
				if a < 0 || b >= len(runes) {
					return nil
				}

				left, right := string(merges[a].runes), string(merges[b].runes)
				rank := bpe.vocab.Merge(left, right)
				if rank < 0 {
					return nil
				}

				return &pair{
					a:     a,
					b:     b,
					rank:  rank,
					value: left + right,
				}
			}

			pairs := heap.NewWith(func(i, j *pair) int {
				return cmp.Compare(i.rank, j.rank)
			})

			for i := range len(runes) - 1 {
				if pair := pairwise(i, i+1); pair != nil {
					pairs.Push(pair)
				}
			}

			for !pairs.Empty() {
				pair, _ := pairs.Pop()

				left, right := merges[pair.a], merges[pair.b]
				if len(left.runes) == 0 || len(right.runes) == 0 ||
					string(left.runes)+string(right.runes) != pair.value {
					continue
				}

				merges[pair.a].runes = append(left.runes, right.runes...)
				merges[pair.b].runes = nil

				merges[pair.a].n = right.n
				if right.n < len(merges) {
					merges[right.n].p = pair.a
				}

				if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
					pairs.Push(pair)
				}

				if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
					pairs.Push(pair)
				}
			}

			for _, merge := range merges {
				if len(merge.runes) > 0 {
					// TODO: handle the edge case where the rune isn't in the vocabulary
					if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
						ids = append(ids, id)
						slog.Debug("encoded", "text", string(merge.runes), "ids", []int32{id})
					}
				}
			}
		}
	}

	return ids, nil
}

func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
	var sb strings.Builder
	for _, id := range ids {
		for _, r := range bpe.vocab.Decode(id) {
			switch {
			case r == 0x0100:
				// this produces 0x00 aka NULL
				continue
			case r == 0x0143:
				r = 0x00ad
			case r > 0x0100 && r <= 0x0120:
				r = r - 0x0100
			case r > 0x0120 && r <= 0x0142:
				r = r - 0x00a2
			}

			// NOTE: not using WriteRune here because it writes the UTF-8
			// encoding of the rune which is _not_ what we want
			if err := sb.WriteByte(byte(r)); err != nil {
				return "", err
			}
		}
	}

	slog.Debug("decoded", "ids", ids, "text", sb.String())
	return sb.String(), nil
}
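A hedged round-trip sketch (not in the diff): real vocabularies come from GGUF metadata, this tiny one is made up. Note that a leading space is stored in its remapped form "Ġ" (0x20 + 0x100), matching how Encode rewrites bytes before vocabulary lookup.

// Hypothetical round trip through the byte-pair encoder.
func roundTrip() (string, error) {
	tokenizer := NewBytePairEncoding(
		` ?\S+`, // toy pre-tokenizer pattern, not llama's
		&Vocabulary{
			Values: []string{"hello", "Ġworld"},
			Types:  []uint32{1, 1},
		},
	)

	ids, err := tokenizer.Encode("hello world") // -> [0, 1]
	if err != nil {
		return "", err
	}

	return tokenizer.Decode(ids) // -> "hello world", nil
}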
254
model/process_text_test.go
Normal file
@ -0,0 +1,254 @@
package model

import (
	"bufio"
	"encoding/json"
	"math"
	"os"
	"path/filepath"
	"slices"
	"strconv"
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
)

func llama(t testing.TB) BytePairEncoding {
	t.Helper()

	f, err := os.Open(filepath.Join("testdata", "llama3.2", "encoder.json"))
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()

	vocab := make(map[string]int32)
	if err := json.NewDecoder(f).Decode(&vocab); err != nil {
		t.Fatal(err)
	}

	types := make([]uint32, len(vocab))
	tokens := make([]string, len(vocab))
	for token, id := range vocab {
		tokens[id] = token
		types[id] = 1
	}

	for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
		if _, ok := vocab[token]; !ok {
			tokens = append(tokens, token) //nolint:makezero
			types = append(types, 3)       //nolint:makezero
			vocab[token] = int32(len(vocab))
		}
	}

	f, err = os.Open(filepath.Join("testdata", "llama3.2", "vocab.bpe"))
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()

	merges := make([]string, 0, 50000)

	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		if !strings.HasPrefix(scanner.Text(), "#") {
			merges = append(merges, scanner.Text())
		}
	}

	return NewBytePairEncoding(
		`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
		&Vocabulary{
			Values: tokens,
			Types:  types,
			Merges: merges,
		},
	)
}

func TestLlama(t *testing.T) {
	tokenizer := llama(t)

	t.Run("simple", func(t *testing.T) {
		t.Parallel()

		ids, err := tokenizer.Encode("hello world")
		if err != nil {
			t.Error(err)
		}

		if diff := cmp.Diff([]int32{15339, 1917}, ids); diff != "" {
			t.Errorf("no match (-theirs +ours):\n%s", diff)
		}

		s, err := tokenizer.Decode([]int32{15339, 1917})
		if err != nil {
			t.Fatal(err)
		}

		if s != "hello world" {
			t.Errorf("got %q, want hello world", s)
		}

		ids, err = tokenizer.Encode("hello <|end_of_text|>")
		if err != nil {
			t.Error(err)
		}

		if diff := cmp.Diff([]int32{15339, 220, 128001}, ids); diff != "" {
			t.Errorf("no match (-theirs +ours):\n%s", diff)
		}
	})

	t.Run("simple repeated", func(t *testing.T) {
		t.Parallel()

		cases := map[string][]int32{
			strings.Repeat("0", 1):  {15},
			strings.Repeat("0", 2):  {410},
			strings.Repeat("0", 3):  {931},
			strings.Repeat("0", 4):  {931, 15},
			strings.Repeat("0", 5):  {931, 410},
			strings.Repeat("0", 6):  {931, 931},
			strings.Repeat("0", 7):  {931, 931, 15},
			strings.Repeat("0", 8):  {931, 931, 410},
			strings.Repeat("0", 9):  {931, 931, 931},
			strings.Repeat("0", 10): {931, 931, 931, 15},
			strings.Repeat("0", 11): {931, 931, 931, 410},
			strings.Repeat("0", 12): {931, 931, 931, 931},
			strings.Repeat("0", 13): {931, 931, 931, 931, 15},
			strings.Repeat("0", 14): {931, 931, 931, 931, 410},
			strings.Repeat("0", 15): {931, 931, 931, 931, 931},
			strings.Repeat("0", 16): {931, 931, 931, 931, 931, 15},
			strings.Repeat("0", 17): {931, 931, 931, 931, 931, 410},
		}

		for s, want := range cases {
			ids, err := tokenizer.Encode(s)
			if err != nil {
				t.Error(err)
			}

			if diff := cmp.Diff(want, ids); diff != "" {
				t.Errorf("%q no match (-theirs +ours):\n%s", s, diff)
			}
		}
	})

	t.Run("basic roundtrip", func(t *testing.T) {
		t.Parallel()

		cases := []string{
			"hello",
			"hello ",
			"hello  ",
			" hello",
			" hello ",
			" hello  ",
			"hello world",
			"请考试我的软件!12345",
		}

		for _, want := range cases {
			ids, err := tokenizer.Encode(want)
			if err != nil {
				t.Error(err)
			}

			if got, err := tokenizer.Decode(ids); err != nil {
				t.Fatal(err)
			} else if got != want {
				t.Errorf("got %q, want %q", got, want)
			}
		}
	})

	t.Run("special", func(t *testing.T) {
		t.Parallel()

		cases := map[string][]int32{
			"<|begin_of_text|>A B!":                                {128000, 32, 426, 0},
			"<|begin_of_text|>A<|end_of_text|>B!":                  {128000, 32, 128001, 33, 0},
			"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
			"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
		}

		for s, want := range cases {
			ids, err := tokenizer.Encode(s)
			if err != nil {
				t.Fatal(err)
			}

			if diff := cmp.Diff(want, ids); diff != "" {
				t.Errorf("no match (-theirs +ours):\n%s", diff)
			}
		}
	})

	t.Run("split", func(t *testing.T) {
		t.Parallel()

		cases := map[string][]string{
			"Hello World!":                   {"Hello", " World", "!"},
			"I'm don't won't":                {"I", "'m", " don", "'t", " won", "'t"},
			"In 2024 there are 366 days":     {"In", " ", "202", "4", " there", " are", " ", "366", " days"},
			"Hello!! ...world":               {"Hello", "!!", " ...", "world"},
			"Hello  World":                   {"Hello", " ", " World"},
			"Hello\nWorld":                   {"Hello", "\n", "World"},
			"Hello, WORLD!! How's it going?": {"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"},
		}

		for s, want := range cases {
			got := slices.Collect(tokenizer.split(s))
			if diff := cmp.Diff(want, got); diff != "" {
				t.Errorf("no match (-theirs +ours):\n%s", diff)
			}
		}
	})
}

func BenchmarkBytePairEncoding(b *testing.B) {
	tokenizer := llama(b)
	bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
	if err != nil {
		b.Fatal(err)
	}

	for i := range 8 {
		n := min(int(math.Pow10(i)), len(bts))
		bts := bts[:n]
		b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
			b.ResetTimer()
			for range b.N {
				_, err := tokenizer.Encode(string(bts))
				if err != nil {
					b.Fatal(err)
				}
			}
		})

		b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
			ids, err := tokenizer.Encode(string(bts))
			if err != nil {
				b.Fatal(err)
			}

			b.ResetTimer()
			for range b.N {
				_, err := tokenizer.Decode(ids)
				if err != nil {
					b.Fatal(err)
				}
			}
		})

		b.Run("split"+strconv.Itoa(n), func(b *testing.B) {
			b.ResetTimer()
			for range b.N {
				slices.Collect(tokenizer.split(string(bts)))
			}
		})
	}
}
128002
model/testdata/llama3.2/encoder.json
vendored
Normal file
File diff suppressed because it is too large
280147
model/testdata/llama3.2/vocab.bpe
vendored
Normal file
File diff suppressed because it is too large
63845
model/testdata/war-and-peace.txt
vendored
Normal file
File diff suppressed because it is too large
@ -19,7 +19,7 @@ import (
	"golang.org/x/text/encoding/unicode"

	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
)

func TestParseFileFile(t *testing.T) {
@ -769,7 +769,7 @@ func getSHA256Digest(t *testing.T, r io.Reader) (string, int64) {
	return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
}

-func createBinFile(t *testing.T, kv map[string]any, ti []llm.Tensor) (string, string) {
+func createBinFile(t *testing.T, kv map[string]any, ti []ggml.Tensor) (string, string) {
	t.Helper()

	f, err := os.CreateTemp(t.TempDir(), "testbin.*.gguf")
@ -778,7 +778,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []llm.Tensor) (string, st
	}
	defer f.Close()

-	if err := llm.WriteGGUF(f, kv, ti); err != nil {
+	if err := ggml.WriteGGUF(f, kv, ti); err != nil {
		t.Fatal(err)
	}
	// Calculate sha256 of file
13
sample/greedy.go
Normal file
@ -0,0 +1,13 @@
package sample

import "gonum.org/v1/gonum/floats"

type greedy struct{}

func Greedy() Sampler {
	return greedy{}
}

func (s greedy) Sample(t []float64) ([]float64, error) {
	return []float64{float64(floats.MaxIdx(t))}, nil
}
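A small illustration (not in the diff) of the greedy sampler's contract: it maps a logit vector to a single index, returned as a one-element slice.

// Hypothetical use: pick the argmax token from a logit vector.
func pickToken(logits []float64) (int, error) {
	out, err := Greedy().Sample(logits)
	if err != nil {
		return 0, err
	}

	return int(out[0]), nil // e.g. logits [0.1, 2.3, 0.7] -> 1
}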
74
sample/sample.go
Normal file
@ -0,0 +1,74 @@
package sample

import (
	"slices"

	"gonum.org/v1/gonum/floats"
	"gonum.org/v1/gonum/stat/sampleuv"
)

type Sampler interface {
	Sample([]float64) ([]float64, error)
}

type Temperature float64

func (s Temperature) Sample(t []float64) ([]float64, error) {
	floats.Div(t, slices.Repeat([]float64{float64(s)}, len(t)))
	return t, nil
}

type softmax struct{}

func Softmax() Sampler {
	return softmax{}
}

func (softmax) Sample(t []float64) ([]float64, error) {
	return t, nil
}

type TopK int

func (s TopK) Sample(t []float64) ([]float64, error) {
	return t, nil
}

type TopP float32

func (s TopP) Sample(t []float64) ([]float64, error) {
	return t, nil
}

type MinP float32

func (s MinP) Sample(t []float64) ([]float64, error) {
	return t, nil
}

type weighed struct{}

func Weighed() Sampler {
	return weighed{}
}

func (s weighed) Sample(t []float64) ([]float64, error) {
	w := sampleuv.NewWeighted(t, nil)
	if v, ok := w.Take(); ok {
		return []float64{float64(v)}, nil
	}

	return t, nil
}

func Sample(floats []float64, samplers ...Sampler) ([]float64, error) {
	var err error
	for _, sampler := range samplers {
		floats, err = sampler.Sample(floats)
		if err != nil {
			return nil, err
		}
	}

	return floats, nil
}
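A sketch of composing samplers through Sample (the temperature value is arbitrary; note that Softmax, TopK, TopP, and MinP are still pass-through stubs in this file):

// Hypothetical pipeline: scale logits by temperature, then take the argmax.
func sampleGreedy(logits []float64) (int, error) {
	out, err := Sample(logits, Temperature(0.7), Greedy())
	if err != nil {
		return 0, err
	}

	return int(out[0]), nil
}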
@ -21,8 +21,8 @@ import (
	"github.com/ollama/ollama/convert"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/format"
+	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/llama"
-	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/template"
	"github.com/ollama/ollama/types/errtypes"
	"github.com/ollama/ollama/types/model"
@ -205,7 +205,7 @@ func detectModelTypeFromFiles(files map[string]string) string {
		return ""
	}

-	ct := llm.DetectGGMLType(buf)
+	ct := ggml.DetectContentType(buf)
	if ct == "gguf" {
		return "gguf"
	}
@ -271,11 +271,11 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
		return nil, err
	}

-	ggml, _, err := llm.DecodeGGML(bin, 0)
+	f, _, err := ggml.Decode(bin, 0)
	if err != nil {
		return nil, err
	}
-	layers := []*layerGGML{{layer, ggml}}
+	layers := []*layerGGML{{layer, f}}

	if !isAdapter {
		return detectChatTemplate(layers)
@ -283,13 +283,13 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
	return layers, nil
}

-func kvFromLayers(baseLayers []*layerGGML) (llm.KV, error) {
+func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
	for _, l := range baseLayers {
		if l.GGML != nil {
			return l.KV(), nil
		}
	}
-	return llm.KV{}, fmt.Errorf("no base model was found")
+	return ggml.KV{}, fmt.Errorf("no base model was found")
}

func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, fn func(resp api.ProgressResponse)) (err error) {
@ -306,7 +306,7 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
		if layer.GGML != nil {
			quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization))
			if quantType != "" && layer.GGML.Name() == "gguf" && layer.MediaType == "application/vnd.ollama.image.model" {
-				want, err := llm.ParseFileType(quantType)
+				want, err := ggml.ParseFileType(quantType)
				if err != nil {
					return err
				}
@ -403,7 +403,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
	ft := layer.GGML.KV().FileType()
	fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantizeType)})

-	want, err := llm.ParseFileType(quantizeType)
+	want, err := ggml.ParseFileType(quantizeType)
	if err != nil {
		return nil, err
	}
@ -433,13 +433,13 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
		return nil, err
	}

-	ggml, _, err := llm.DecodeGGML(temp, 0)
+	f, _, err := ggml.Decode(temp, 0)
	if err != nil {
		slog.Error(fmt.Sprintf("error decoding ggml: %s\n", err))
		return nil, err
	}

-	return &layerGGML{newLayer, ggml}, nil
+	return &layerGGML{newLayer, f}, nil
}

func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) {
@ -475,7 +475,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML

	var offset int64
	for offset < stat.Size() {
-		ggml, n, err := llm.DecodeGGML(blob, 0)
+		f, n, err := ggml.Decode(blob, 0)
		if errors.Is(err, io.EOF) {
			break
		} else if err != nil {
@ -483,9 +483,9 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
		}

		mediatype := "application/vnd.ollama.image.model"
-		if ggml.KV().Kind() == "adapter" {
+		if f.KV().Kind() == "adapter" {
			mediatype = "application/vnd.ollama.image.adapter"
-		} else if _, ok := ggml.KV()[fmt.Sprintf("%s.vision.block_count", ggml.KV().Architecture())]; ok || ggml.KV().Kind() == "projector" {
+		} else if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok || f.KV().Kind() == "projector" {
			mediatype = "application/vnd.ollama.image.projector"
		}

@ -506,7 +506,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
		}
	}

-		layers = append(layers, &layerGGML{layer, ggml})
+		layers = append(layers, &layerGGML{layer, f})
		offset = n
	}
@ -23,7 +23,7 @@ import (

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/parser"
	"github.com/ollama/ollama/template"
	"github.com/ollama/ollama/types/model"
@ -78,21 +78,21 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
	for _, cap := range caps {
		switch cap {
		case CapabilityCompletion:
-			f, err := os.Open(m.ModelPath)
+			r, err := os.Open(m.ModelPath)
			if err != nil {
				slog.Error("couldn't open model file", "error", err)
				continue
			}
-			defer f.Close()
+			defer r.Close()

			// TODO(mxyng): decode the GGML into model to avoid doing this multiple times
-			ggml, _, err := llm.DecodeGGML(f, 0)
+			f, _, err := ggml.Decode(r, 0)
			if err != nil {
				slog.Error("couldn't decode ggml", "error", err)
				continue
			}

-			if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
+			if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
				errs = append(errs, errCapabilityCompletion)
			}
		case CapabilityTools:
@ -15,7 +15,7 @@ import (
	"text/template/parse"

	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/template"
	"github.com/ollama/ollama/types/model"
)
@ -24,7 +24,7 @@ var intermediateBlobs map[string]string = make(map[string]string)

type layerGGML struct {
	Layer
-	*llm.GGML
+	*ggml.GGML
}

func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
@ -64,12 +64,12 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
		}
		defer blob.Close()

-		ggml, _, err := llm.DecodeGGML(blob, 0)
+		f, _, err := ggml.Decode(blob, 0)
		if err != nil {
			return nil, err
		}

-		layers = append(layers, &layerGGML{layer, ggml})
+		layers = append(layers, &layerGGML{layer, f})
	default:
		layers = append(layers, &layerGGML{layer, nil})
	}
@ -118,7 +118,7 @@ func detectContentType(r io.Reader) (string, error) {
		return "", err
	}

-	if contentType := llm.DetectGGMLType(b.Bytes()); contentType != "" {
+	if contentType := ggml.DetectContentType(b.Bytes()); contentType != "" {
		return contentType, nil
	}
@ -30,6 +30,7 @@ import (
	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/envconfig"
+	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/model/mllama"
	"github.com/ollama/ollama/openai"
@ -860,7 +861,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
	return resp, nil
}

-func getKVData(digest string, verbose bool) (llm.KV, error) {
+func getKVData(digest string, verbose bool) (ggml.KV, error) {
	maxArraySize := 0
	if verbose {
		maxArraySize = -1
@ -19,12 +19,12 @@ import (

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
)

var stream bool = false

-func createBinFile(t *testing.T, kv map[string]any, ti []llm.Tensor) (string, string) {
+func createBinFile(t *testing.T, kv map[string]any, ti []ggml.Tensor) (string, string) {
	t.Helper()
	t.Setenv("OLLAMA_MODELS", cmp.Or(os.Getenv("OLLAMA_MODELS"), t.TempDir()))

@ -36,7 +36,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []llm.Tensor) (string, st
	}
	defer f.Close()

-	if err := llm.WriteGGUF(f, kv, ti); err != nil {
+	if err := ggml.WriteGGUF(f, kv, ti); err != nil {
		t.Fatal(err)
	}
	// Calculate sha256 of file
@ -672,7 +672,7 @@ func TestCreateDetectTemplate(t *testing.T) {
	var s Server

	t.Run("matched", func(t *testing.T) {
-		_, digest := createBinFile(t, llm.KV{
+		_, digest := createBinFile(t, ggml.KV{
			"tokenizer.chat_template": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
		}, nil)
		w := createRequest(t, s.CreateHandler, api.CreateRequest{
@ -16,6 +16,7 @@ import (

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/discover"
+	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/llm"
)

@ -45,8 +46,8 @@ func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error
	return
}

-func newMockServer(mock *mockRunner) func(discover.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
-	return func(gpus discover.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
+func newMockServer(mock *mockRunner) func(discover.GpuInfoList, string, *ggml.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
+	return func(_ discover.GpuInfoList, _ string, _ *ggml.GGML, _, _ []string, _ api.Options, _ int) (llm.LlamaServer, error) {
		return mock, nil
	}
}
@ -76,7 +77,7 @@ func TestGenerateChat(t *testing.T) {
		getGpuFn:     discover.GetGPUInfo,
		getCpuFn:     discover.GetCPUInfo,
		reschedDelay: 250 * time.Millisecond,
-		loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
+		loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) {
			// add small delay to simulate loading
			time.Sleep(time.Millisecond)
			req.successCh <- &runnerRef{
@ -88,7 +89,7 @@ func TestGenerateChat(t *testing.T) {

	go s.sched.Run(context.TODO())

-	_, digest := createBinFile(t, llm.KV{
+	_, digest := createBinFile(t, ggml.KV{
		"general.architecture": "llama",
		"llama.block_count":    uint32(1),
		"llama.context_length": uint32(8192),
@ -98,7 +99,7 @@ func TestGenerateChat(t *testing.T) {
		"tokenizer.ggml.tokens":     []string{""},
		"tokenizer.ggml.scores":     []float32{0},
		"tokenizer.ggml.token_type": []int32{0},
-	}, []llm.Tensor{
+	}, []ggml.Tensor{
		{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
@ -154,10 +155,10 @@ func TestGenerateChat(t *testing.T) {
	})

	t.Run("missing capabilities chat", func(t *testing.T) {
-		_, digest := createBinFile(t, llm.KV{
+		_, digest := createBinFile(t, ggml.KV{
			"general.architecture": "bert",
			"bert.pooling_type":    uint32(0),
-		}, []llm.Tensor{})
+		}, []ggml.Tensor{})
		w := createRequest(t, s.CreateHandler, api.CreateRequest{
			Model: "bert",
			Files: map[string]string{"bert.gguf": digest},
@ -612,7 +613,7 @@ func TestGenerate(t *testing.T) {
		getGpuFn:     discover.GetGPUInfo,
		getCpuFn:     discover.GetCPUInfo,
		reschedDelay: 250 * time.Millisecond,
-		loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
+		loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) {
			// add small delay to simulate loading
			time.Sleep(time.Millisecond)
			req.successCh <- &runnerRef{
@ -624,7 +625,7 @@ func TestGenerate(t *testing.T) {

	go s.sched.Run(context.TODO())

-	_, digest := createBinFile(t, llm.KV{
+	_, digest := createBinFile(t, ggml.KV{
		"general.architecture": "llama",
		"llama.block_count":    uint32(1),
		"llama.context_length": uint32(8192),
@ -634,7 +635,7 @@ func TestGenerate(t *testing.T) {
		"tokenizer.ggml.tokens":     []string{""},
		"tokenizer.ggml.scores":     []float32{0},
		"tokenizer.ggml.token_type": []int32{0},
-	}, []llm.Tensor{
+	}, []ggml.Tensor{
		{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
@ -686,10 +687,10 @@ func TestGenerate(t *testing.T) {
	})

	t.Run("missing capabilities generate", func(t *testing.T) {
-		_, digest := createBinFile(t, llm.KV{
+		_, digest := createBinFile(t, ggml.KV{
			"general.architecture": "bert",
			"bert.pooling_type":    uint32(0),
-		}, []llm.Tensor{})
+		}, []ggml.Tensor{})

		w := createRequest(t, s.CreateHandler, api.CreateRequest{
			Model: "bert",
@ -21,7 +21,7 @@ import (
	"unicode"

	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/openai"
	"github.com/ollama/ollama/types/model"
	"github.com/ollama/ollama/version"
@ -654,8 +654,8 @@ func TestShow(t *testing.T) {

	var s Server

-	_, digest1 := createBinFile(t, llm.KV{"general.architecture": "test"}, nil)
-	_, digest2 := createBinFile(t, llm.KV{"general.type": "projector", "general.architecture": "clip"}, nil)
+	_, digest1 := createBinFile(t, ggml.KV{"general.architecture": "test"}, nil)
+	_, digest2 := createBinFile(t, ggml.KV{"general.type": "projector", "general.architecture": "clip"}, nil)

	createRequest(t, s.CreateHandler, api.CreateRequest{
		Name: "show-model",
@ -18,6 +18,7 @@ import (
|
||||
"github.com/ollama/ollama/discover"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
@ -41,8 +42,8 @@ type Scheduler struct {
|
||||
loaded map[string]*runnerRef
|
||||
loadedMu sync.Mutex
|
||||
|
||||
loadFn func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int)
|
||||
newServerFn func(gpus discover.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error)
|
||||
loadFn func(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, numParallel int)
|
||||
newServerFn func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error)
|
||||
getGpuFn func() discover.GpuInfoList
|
||||
getCpuFn func() discover.GpuInfoList
|
||||
reschedDelay time.Duration
|
||||
@ -409,7 +410,7 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
|
||||
func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, numParallel int) {
|
||||
if numParallel < 1 {
|
||||
numParallel = 1
|
||||
}
|
||||
@ -417,12 +418,12 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoL
|
||||
if req.sessionDuration != nil {
|
||||
sessionDuration = req.sessionDuration.Duration
|
||||
}
|
||||
llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
|
||||
llama, err := s.newServerFn(gpus, req.model.ModelPath, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
|
||||
if err != nil {
|
||||
// some older models are not compatible with newer versions of llama.cpp
|
||||
// show a generalized compatibility error until there is a better way to
|
||||
// check for model compatibility
|
||||
if errors.Is(err, llm.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") {
|
||||
if errors.Is(err, ggml.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") {
|
||||
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
|
||||
}
|
||||
slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)
|
||||
@@ -685,7 +686,7 @@ func (a ByDuration) Less(i, j int) bool {
 // If the model can not be fit fully within the available GPU(s) nil is returned
 // If numParallel is <= 0, this will attempt to optimize parallelism based on available VRAM, and adjust
 // opts.NumCtx accordingly
-func pickBestFullFitByLibrary(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel *int) discover.GpuInfoList {
+func pickBestFullFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, numParallel *int) discover.GpuInfoList {
 	var estimatedVRAM uint64
 
 	var numParallelToTry []int
@@ -710,7 +711,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, ggml *llm.GGML, gpus discover.Gpu
 		req.opts.NumCtx = req.origNumCtx * p
 		if !envconfig.SchedSpread() {
 			for _, g := range sgl {
-				if ok, estimatedVRAM = llm.PredictServerFit([]discover.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
+				if ok, estimatedVRAM = llm.PredictServerFit([]discover.GpuInfo{g}, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
 					slog.Info("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "parallel", p, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
 					*numParallel = p
 					return []discover.GpuInfo{g}
@@ -726,7 +727,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, ggml *llm.GGML, gpus discover.Gpu
 	// Now try all the GPUs
 	for _, p := range numParallelToTry {
 		req.opts.NumCtx = req.origNumCtx * p
-		if ok, estimatedVRAM = llm.PredictServerFit(sgl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
+		if ok, estimatedVRAM = llm.PredictServerFit(sgl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
 			slog.Info("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "parallel", p, "required", format.HumanBytes2(estimatedVRAM))
 			*numParallel = p
 			return sgl
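Taken together, the two loops above implement a simple search: for each candidate parallelism value, scale the context window and test whether the whole model fits, first on a single GPU and then across all of them. A toy sketch of that strategy, with `fits` standing in for `llm.PredictServerFit` and all sizes made up for illustration:

```go
package main

import "fmt"

func main() {
	origNumCtx := 2048
	freeVRAM := uint64(12 << 30) // 12 GiB, illustrative

	// Stand-in for llm.PredictServerFit: weights plus a KV cache that
	// grows linearly with the context length.
	fits := func(numCtx int) bool {
		need := uint64(8<<30) + uint64(numCtx)*(1<<20)
		return need <= freeVRAM
	}

	// In this sketch we try the most parallelism first.
	for _, p := range []int{4, 2, 1} {
		numCtx := origNumCtx * p
		if fits(numCtx) {
			fmt.Printf("full fit with parallel=%d (ctx=%d)\n", p, numCtx)
			return
		}
	}
	fmt.Println("no full fit; fall back to partial offload")
}
```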
@@ -737,7 +738,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, ggml *llm.GGML, gpus discover.Gpu
 }
 
 // If multiple Libraries are detected, pick the Library which loads the most layers for the model
-func pickBestPartialFitByLibrary(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel *int) discover.GpuInfoList {
+func pickBestPartialFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, numParallel *int) discover.GpuInfoList {
 	if *numParallel <= 0 {
 		*numParallel = 1
 		req.opts.NumCtx = req.origNumCtx
@@ -749,7 +750,7 @@ func pickBestPartialFitByLibrary(req *LlmRequest, ggml *llm.GGML, gpus discover.
 	var bestEstimate uint64
 	var bestFit int
 	for i, gl := range byLibrary {
-		_, estimatedVRAM := llm.PredictServerFit(gl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
+		_, estimatedVRAM := llm.PredictServerFit(gl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
 		if estimatedVRAM > bestEstimate {
 			bestEstimate = estimatedVRAM
 			bestFit = i
@@ -822,9 +823,9 @@ func (s *Scheduler) expireRunner(model *Model) {
 
 // If other runners are loaded, make sure the pending request will fit in system memory
 // If not, pick a runner to unload, else return nil and the request can be loaded
-func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList) *runnerRef {
+func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList) *runnerRef {
 	slog.Debug("evaluating if CPU model load will fit in available system memory")
-	estimate := llm.EstimateGPULayers(gpus, ggml, req.model.ProjectorPaths, req.opts)
+	estimate := llm.EstimateGPULayers(gpus, f, req.model.ProjectorPaths, req.opts)
 	if estimate.TotalSize <= gpus[0].FreeMemory {
 		slog.Debug("cpu inference mode, model fits in available system memory", "model", format.HumanBytes2(estimate.TotalSize), "available", format.HumanBytes2(gpus[0].FreeMemory))
 		return nil
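The CPU path reduces to a single comparison: the estimated total footprint of the pending model versus free system memory, with eviction only when it does not fit. A sketch with illustrative numbers, where `estimatedTotal` stands in for `llm.EstimateGPULayers(...).TotalSize`:

```go
package main

import "fmt"

func main() {
	estimatedTotal := uint64(6 << 30) // stand-in for estimate.TotalSize
	freeSystemMem := uint64(8 << 30)  // stand-in for gpus[0].FreeMemory on the CPU pseudo-GPU

	if estimatedTotal <= freeSystemMem {
		fmt.Println("cpu inference mode, model fits in available system memory")
	} else {
		fmt.Println("insufficient memory: pick a loaded runner to unload")
	}
}
```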

@@ -15,6 +15,7 @@ import (
 	"github.com/ollama/ollama/app/lifecycle"
 	"github.com/ollama/ollama/discover"
 	"github.com/ollama/ollama/format"
+	"github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/llm"
 )
 
@@ -37,7 +38,7 @@ func TestLoad(t *testing.T) {
 	ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond)
 	defer done()
 	s := InitScheduler(ctx)
-	var ggml *llm.GGML // value not used in tests
+	var f *ggml.GGML // value not used in tests
 	req := &LlmRequest{
 		ctx:   ctx,
 		model: &Model{ModelPath: "foo"},
@@ -47,11 +48,11 @@ func TestLoad(t *testing.T) {
 		sessionDuration: &api.Duration{Duration: 2 * time.Second},
 	}
 	// Fail to load model first
-	s.newServerFn = func(gpus discover.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
+	s.newServerFn = func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
 		return nil, errors.New("something failed to load model blah")
 	}
 	gpus := discover.GpuInfoList{}
-	s.load(req, ggml, gpus, 0)
+	s.load(req, f, gpus, 0)
 	require.Empty(t, req.successCh)
 	require.Len(t, req.errCh, 1)
 	s.loadedMu.Lock()
@@ -61,10 +62,10 @@ func TestLoad(t *testing.T) {
 	require.Contains(t, err.Error(), "this model may be incompatible")
 
 	server := &mockLlm{estimatedVRAM: 10, estimatedVRAMByGPU: map[string]uint64{}}
-	s.newServerFn = func(gpus discover.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
+	s.newServerFn = func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
 		return server, nil
 	}
-	s.load(req, ggml, gpus, 0)
+	s.load(req, f, gpus, 0)
 	select {
 	case err := <-req.errCh:
 		require.NoError(t, err)
@@ -78,7 +79,7 @@ func TestLoad(t *testing.T) {
 
 	req.model.ModelPath = "dummy_model_path"
 	server.waitResp = errors.New("wait failure")
-	s.load(req, ggml, gpus, 0)
+	s.load(req, f, gpus, 0)
 	select {
 	case err := <-req.errCh:
 		require.Contains(t, err.Error(), "wait failure")
@@ -99,10 +100,10 @@ type reqBundle struct {
 	ctxDone func()
 	srv     *mockLlm
 	req     *LlmRequest
-	ggml    *llm.GGML
+	f       *ggml.GGML
 }
 
-func (scenario *reqBundle) newServer(gpus discover.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
+func (scenario *reqBundle) newServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
 	return scenario.srv, nil
 }
 
@@ -115,7 +116,7 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
 	require.NoError(t, err)
 	defer f.Close()
 
-	require.NoError(t, llm.WriteGGUF(f, llm.KV{
+	require.NoError(t, ggml.WriteGGUF(f, ggml.KV{
 		"general.architecture":   "llama",
 		"llama.context_length":   uint32(32),
 		"llama.embedding_length": uint32(4096),
@@ -125,7 +126,7 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
 		"tokenizer.ggml.tokens":     []string{" "},
 		"tokenizer.ggml.scores":     []float32{0},
 		"tokenizer.ggml.token_type": []int32{0},
-	}, []llm.Tensor{
+	}, []ggml.Tensor{
 		{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
 		{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
 	}))
@@ -133,7 +134,7 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
 
 	fname := f.Name()
 	model := &Model{Name: modelName, ModelPath: fname}
-	b.ggml, err = llm.LoadModel(model.ModelPath, 0)
+	b.f, err = llm.LoadModel(model.ModelPath, 0)
 	require.NoError(t, err)
 
 	if duration == nil {
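The fixture pattern above, write a minimal GGUF to disk and parse it back, is worth extracting if more tests need it. A condensed, hypothetical helper built only from the calls the diff itself shows (`ggml.WriteGGUF`, `ggml.KV`, `ggml.Tensor`, `llm.LoadModel`); the metadata keys are trimmed for brevity, so treat this as a sketch rather than a drop-in replacement:

```go
package server

import (
	"bytes"
	"os"
	"testing"

	"github.com/stretchr/testify/require"

	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/llm"
)

// newTestGGUF is a hypothetical helper that assumes the same signatures used
// in newScenarioRequest above. The real fixture also sets block count,
// attention head counts, and the tokenizer keys.
func newTestGGUF(t *testing.T) *ggml.GGML {
	t.Helper()

	f, err := os.CreateTemp(t.TempDir(), "*.gguf")
	require.NoError(t, err)
	defer f.Close()

	require.NoError(t, ggml.WriteGGUF(f, ggml.KV{
		"general.architecture":   "llama",
		"llama.context_length":   uint32(32),
		"llama.embedding_length": uint32(4096),
	}, []ggml.Tensor{
		{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
	}))

	// Round-trip: parse the file back the same way the scheduler does.
	g, err := llm.LoadModel(f.Name(), 0)
	require.NoError(t, err)
	return g
}
```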
@@ -174,7 +175,7 @@ func TestRequestsSameModelSameRequest(t *testing.T) {
 	a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
 	b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0})
 	b.req.model = a.req.model
-	b.ggml = a.ggml
+	b.f = a.f
 
 	s.newServerFn = a.newServer
 	slog.Info("a")
@@ -218,7 +219,7 @@ func TestRequestsSimpleReloadSameModel(t *testing.T) {
 	b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond})
 	tmpModel := *a.req.model
 	b.req.model = &tmpModel
-	b.ggml = a.ggml
+	b.f = a.f
 
 	s.newServerFn = a.newServer
 	slog.Info("a")
@@ -419,13 +420,13 @@ func TestExpireRunner(t *testing.T) {
 		sessionDuration: &api.Duration{Duration: 2 * time.Minute},
 	}
 
-	var ggml *llm.GGML
+	var f *ggml.GGML
 	gpus := discover.GpuInfoList{}
 	server := &mockLlm{estimatedVRAM: 10, estimatedVRAMByGPU: map[string]uint64{}}
-	s.newServerFn = func(gpus discover.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
+	s.newServerFn = func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
 		return server, nil
 	}
-	s.load(req, ggml, gpus, 0)
+	s.load(req, f, gpus, 0)
 
 	select {
 	case err := <-req.errCh:
@@ -729,9 +730,9 @@ func TestHomogeneousGPUs(t *testing.T) {
 	}
 	s.getCpuFn = getCpuFn
 	a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
-	s.newServerFn = func(gpus discover.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
+	s.newServerFn = func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
 		require.Len(t, gpus, 1)
-		return a.newServer(gpus, model, ggml, adapters, projectors, opts, numParallel)
+		return a.newServer(gpus, model, f, adapters, projectors, opts, numParallel)
 	}
 	slog.Info("a")
 	s.pendingReqCh <- a.req

@@ -14,7 +14,7 @@ import (
 	"github.com/google/go-cmp/cmp"
 
 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/fs/ggml"
 )
 
 func TestNamed(t *testing.T) {
@@ -33,7 +33,7 @@ func TestNamed(t *testing.T) {
 
 	for k, v := range ss {
 		t.Run(k, func(t *testing.T) {
-			kv := llm.KV{"tokenizer.chat_template": v}
+			kv := ggml.KV{"tokenizer.chat_template": v}
 			s := kv.ChatTemplate()
 			r, err := Named(s)
 			if err != nil {
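What this test exercises: `ggml.KV` is the GGUF key/value metadata map, and `ChatTemplate()` pulls the `tokenizer.chat_template` entry back out so `Named` can match it against the bundled templates. A minimal sketch of that lookup (the template string is made up):

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/fs/ggml"
)

func main() {
	// The same construction the test uses, with an illustrative template.
	kv := ggml.KV{"tokenizer.chat_template": "{{ .Prompt }}"}

	// ChatTemplate reads the "tokenizer.chat_template" key back out.
	fmt.Println(kv.ChatTemplate())
}
```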