From cf1dbcfc5af6448f5f39df8513fcba2a9b7c71ed Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Fri, 10 Jan 2025 10:32:38 -0800
Subject: [PATCH] next bert

---
 ml/backend/ggml/ggml.go  |  21 +++---
 model/bert/model.go      | 156 +++++++++++++++++++++++++++++++++++++++
 model/bert/model_test.go |  75 ++++++++++++++++++
 3 files changed, 242 insertions(+), 10 deletions(-)
 create mode 100644 model/bert/model.go
 create mode 100644 model/bert/model_test.go

diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 8274f3ebb..a3a4bec49 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -23,7 +23,7 @@ import (
 	"github.com/ollama/ollama/ml"
 	"golang.org/x/sync/errgroup"
 
-	"github.com/ollama/ollama/ml/backend/ggml/ggml/src"
+	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
 )
 
 type device struct {
@@ -249,8 +249,8 @@ func (c *Context) Compute(t ml.Tensor) ml.Tensor {
 
 	backend := C.ggml_backend_sched_get_tensor_backend(c.sched, t.(*Tensor).t)
 
-	t.(*Tensor).data = make([]byte, C.ggml_nbytes(t.(*Tensor).t))
-	C.ggml_backend_tensor_get_async(backend, t.(*Tensor).t, unsafe.Pointer(&t.(*Tensor).data[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
+	t.(*Tensor).bytes = make([]byte, C.ggml_nbytes(t.(*Tensor).t))
+	C.ggml_backend_tensor_get_async(backend, t.(*Tensor).t, unsafe.Pointer(&t.(*Tensor).bytes[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
 
 	return t
 }
@@ -313,8 +313,8 @@ func (c *Context) Close() error {
 }
 
 type Tensor struct {
-	t    *C.struct_ggml_tensor
-	data []byte
+	t     *C.struct_ggml_tensor
+	bytes []byte
 }
 
 func (t *Tensor) LogValue() slog.Value {
@@ -343,17 +343,18 @@ func (t *Tensor) Shape() []int64 {
 }
 
 func (t *Tensor) Bytes() []byte {
-	if bts := C.ggml_get_data(t.t); bts != nil {
-		return C.GoBytes(bts, C.int(C.ggml_nbytes(t.t)))
+	if t.bytes == nil {
+		cbytes := C.ggml_get_data(t.t)
+		t.bytes = C.GoBytes(unsafe.Pointer(cbytes), C.int(C.ggml_nbytes(t.t)))
 	}
 
-	return nil
+	return t.bytes
 }
 
 func (t *Tensor) Floats() (f32s []float32) {
-	if t.data != nil {
+	if t.bytes != nil {
 		f32s = make([]float32, C.ggml_nelements(t.t))
-		_ = binary.Read(bytes.NewReader(t.data), binary.LittleEndian, f32s)
+		_ = binary.Read(bytes.NewReader(t.bytes), binary.LittleEndian, f32s)
 	}
 
 	return
diff --git a/model/bert/model.go b/model/bert/model.go
new file mode 100644
index 000000000..e4f0bec27
--- /dev/null
+++ b/model/bert/model.go
@@ -0,0 +1,156 @@
+package bert
+
+import (
+	"fmt"
+	"math"
+
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/model"
+)
+
+func init() {
+	model.Register("bert", New)
+}
+
+type Options struct {
+	hiddenSize, numHeads int64
+	eps                  float32
+}
+
+type Model struct {
+	model.Base
+	model.BytePairEncoding
+
+	TokenEmbedding     *nn.Embedding `ggml:"token_embd"`
+	TypeEmbedding      *nn.Embedding `ggml:"type_embd,alt:token_types"`
+	PositionEmbedding  *nn.Embedding `ggml:"position_embd"`
+	TokenEmbeddingNorm *nn.LayerNorm `ggml:"token_embd_norm"`
+
+	Layers []EncoderLayer `ggml:"blk"`
+
+	*Options
+}
+
+// Forward implements model.Model.
+func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
+	inputs, err := ctx.FromIntSlice(opts.Inputs(), len(opts.Inputs()))
+	if err != nil {
+		return nil, err
+	}
+	fmt.Println("inputs", inputs.Shape(), ml.Dump(inputs))
+
+	types, err := ctx.FromIntSlice([]int32{0}, 1)
+	if err != nil {
+		return nil, err
+	}
+	fmt.Println("types", types.Shape(), ml.Dump(types))
+
+	positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions()))
+	if err != nil {
+		return nil, err
+	}
+	fmt.Println("positions", positions.Shape(), ml.Dump(positions))
+
+	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+	fmt.Println("TokenEmbedding.Forward", hiddenState.Shape(), ml.Dump(hiddenState))
+	hiddenState = hiddenState.Add(ctx, m.TypeEmbedding.Forward(ctx, types))
+	fmt.Println("TypeEmbedding.Forward", hiddenState.Shape(), ml.Dump(hiddenState))
+	hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positions))
+	fmt.Println("PositionEmbedding.Forward", hiddenState.Shape(), ml.Dump(hiddenState))
+	hiddenState = m.TokenEmbeddingNorm.Forward(ctx, hiddenState, m.eps)
+	fmt.Println("TokenEmbeddingNorm.Forward", hiddenState.Shape(), ml.Dump(hiddenState))
+
+	for i, layer := range m.Layers {
+		hiddenState = layer.Forward(ctx, hiddenState, positions, opts.Cache.Sub(i), m.Options)
+		fmt.Println("EncoderLayer.Forward", i, hiddenState.Shape(), ml.Dump(hiddenState))
+	}
+
+	return hiddenState, nil
+}
+
+type EncoderLayer struct {
+	*SelfAttention
+	MLPNorm *nn.LayerNorm `ggml:"attn_output_norm"`
+	*MLP
+	LayerOutputNorm *nn.LayerNorm `ggml:"ffn_output_norm"`
+}
+
+func (e *EncoderLayer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor {
+	residual := hiddenState
+
+	hiddenState = e.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
+	hiddenState = hiddenState.Add(ctx, residual)
+	hiddenState = e.MLPNorm.Forward(ctx, hiddenState, opts.eps)
+	residual = hiddenState
+
+	hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
+	hiddenState = hiddenState.Add(ctx, residual)
+	return e.LayerOutputNorm.Forward(ctx, hiddenState, opts.eps)
+}
+
+type SelfAttention struct {
+	Query  *nn.Linear `ggml:"attn_q"`
+	Key    *nn.Linear `ggml:"attn_k"`
+	Value  *nn.Linear `ggml:"attn_v"`
+	Output *nn.Linear `ggml:"attn_output"`
+}
+
+func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor {
+	batchSize := hiddenState.Dim(1)
+	headDim := opts.hiddenSize / opts.numHeads
+
+	query := sa.Query.Forward(ctx, hiddenState)
+	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
+
+	key := sa.Key.Forward(ctx, hiddenState)
+	key = key.Reshape(ctx, headDim, opts.numHeads, batchSize)
+
+	value := sa.Value.Forward(ctx, hiddenState)
+	value = value.Reshape(ctx, headDim, opts.numHeads, batchSize)
+
+	key, value = cache.Put(ctx, key, value, cache.Options)
+
+	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+
+	scores := key.Mulmat(ctx, query)
+	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
+	scores = scores.Softmax(ctx)
+
+	attention := value.Mulmat(ctx, scores)
+	attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
+
+	return sa.Output.Forward(ctx, attention)
+}
+
+type MLP struct {
+	Up   *nn.Linear `ggml:"ffn_up"`
+	Down *nn.Linear `ggml:"ffn_down"`
+}
+
+func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
+	return mlp.Down.Forward(ctx, mlp.Up.Forward(ctx, hiddenState).GELU(ctx))
+}
+
+func New(c ml.Config) (model.Model, error) {
+	return &Model{
+		BytePairEncoding: model.NewBytePairEncoding(
+			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
+			&model.Vocabulary{
+				Values: c.Strings("tokenizer.ggml.tokens"),
+				Types:  c.Uints("tokenizer.ggml.token_type"),
+				Merges: c.Strings("tokenizer.ggml.merges"),
+				BOS:    c.Uint("tokenizer.ggml.bos_token_id"),
+				EOS:    c.Uint("tokenizer.ggml.eos_token_id"),
+			},
+		),
+		Options: &Options{
+			hiddenSize: int64(c.Uint("embedding_length")),
+			numHeads:   int64(c.Uint("attention.head_count")),
+			eps:        c.Float("attention.layer_norm_epsilon"),
+		},
+	}, nil
+}
diff --git a/model/bert/model_test.go b/model/bert/model_test.go
new file mode 100644
index 000000000..705abb7d1
--- /dev/null
+++ b/model/bert/model_test.go
@@ -0,0 +1,75 @@
+package bert_test
+
+import (
+	"encoding/json"
+	"os"
+	"path/filepath"
+	"strings"
+	"testing"
+
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model"
+)
+
+func blob(t *testing.T, tag string) string {
+	t.Helper()
+
+	home, err := os.UserHomeDir()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	p := filepath.Join(home, ".ollama", "models")
+	manifestBytes, err := os.ReadFile(filepath.Join(p, "manifests", "registry.ollama.ai", "library", "all-minilm", tag))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	var manifest struct {
+		Layers []struct {
+			MediaType string `json:"mediaType"`
+			Digest    string `json:"digest"`
+		}
+	}
+
+	if err := json.Unmarshal(manifestBytes, &manifest); err != nil {
+		t.Fatal(err)
+	}
+
+	var digest string
+	for _, layer := range manifest.Layers {
+		if layer.MediaType == "application/vnd.ollama.image.model" {
+			digest = layer.Digest
+			break
+		}
+	}
+
+	if digest == "" {
+		t.Fatal("no model layer found")
+	}
+
+	return filepath.Join(p, "blobs", strings.ReplaceAll(digest, ":", "-"))
+}
+
+func TestEmbedding(t *testing.T) {
+	m, err := model.New(blob(t, "latest"))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	text, err := os.ReadFile(filepath.Join("..", "testdata", "war-and-peace.txt"))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	inputIDs, err := m.(model.TextProcessor).Encode(string(text))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	logit, err := model.Forward(m, model.WithInputIDs(inputIDs))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	t.Log(ml.Dump(logit))
+}
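
A note on the ggml.go changes: Compute now snapshots the output tensor's backend memory into the renamed bytes field, Bytes lazily fills the same field, and Floats decodes whichever snapshot exists, so results are read without another cgo round-trip. A minimal, self-contained sketch of that decode step follows; the byte values are made up, not taken from the patch.

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

func main() {
	// Two float32 values, little-endian, as Compute would leave them in
	// the cached byte slice: 1.0 (0x3f800000) and 2.0 (0x40000000).
	raw := []byte{0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40}

	f32s := make([]float32, len(raw)/4)
	if err := binary.Read(bytes.NewReader(raw), binary.LittleEndian, f32s); err != nil {
		panic(err)
	}

	fmt.Println(f32s) // [1 2]
}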
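
Forward in model/bert/model.go follows the standard BERT embedding recipe before the encoder stack: sum the token, token-type, and position embeddings, then layer-normalize. Here is a hedged sketch of that arithmetic on plain slices; the values are toy numbers, and the real nn.LayerNorm also applies a learned scale and bias that this helper omits.

package main

import (
	"fmt"
	"math"
)

// layerNorm normalizes x to zero mean and unit variance; nn.LayerNorm
// additionally multiplies by a learned weight and adds a learned bias.
func layerNorm(x []float64, eps float64) []float64 {
	var mean float64
	for _, v := range x {
		mean += v
	}
	mean /= float64(len(x))

	var variance float64
	for _, v := range x {
		variance += (v - mean) * (v - mean)
	}
	variance /= float64(len(x))

	out := make([]float64, len(x))
	for i, v := range x {
		out[i] = (v - mean) / math.Sqrt(variance+eps)
	}
	return out
}

func main() {
	// Toy 4-dimensional embeddings for one token at position 0, type 0.
	tokEmb := []float64{0.1, -0.2, 0.3, 0.0}
	typeEmb := []float64{0.05, 0.05, 0.05, 0.05}
	posEmb := []float64{0.0, 0.1, -0.1, 0.2}

	sum := make([]float64, len(tokEmb))
	for i := range sum {
		sum[i] = tokEmb[i] + typeEmb[i] + posEmb[i]
	}

	fmt.Println(layerNorm(sum, 1e-12)) // input to the first EncoderLayer
}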
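
For reviewers tracing shapes through SelfAttention.Forward, here is a hedged sketch that replays the Reshape/Permute/Mulmat sequence on plain shape triples. It assumes ggml's column-major convention (Dim(0) is the contiguous axis), that Permute(a0, a1, a2, a3) sends input axis i to output position ai, and that Mulmat(a, b) contracts the first axis of both operands per trailing batch axis; the concrete sizes are illustrative, roughly all-minilm-like, not taken from the patch.

package main

import "fmt"

// permute applies a ggml-style axis permutation to a 3-D shape:
// input axis i moves to output position dst[i].
func permute(shape [3]int64, dst [3]int) [3]int64 {
	var out [3]int64
	for i, d := range dst {
		out[d] = shape[i]
	}
	return out
}

func main() {
	hiddenSize, numHeads, seqLen := int64(384), int64(12), int64(7) // illustrative sizes
	headDim := hiddenSize / numHeads                                // 32

	// Query/Key/Value after Linear + Reshape: [headDim, numHeads, seqLen].
	q := [3]int64{headDim, numHeads, seqLen}
	k := [3]int64{headDim, numHeads, seqLen}
	v := [3]int64{headDim, numHeads, seqLen}

	q = permute(q, [3]int{0, 2, 1}) // Permute(0,2,1,3): [headDim, seqLen, numHeads]
	k = permute(k, [3]int{0, 2, 1}) // same per-head layout as q
	v = permute(v, [3]int{1, 2, 0}) // Permute(1,2,0,3): [seqLen, headDim, numHeads]

	scores := [3]int64{k[1], q[1], numHeads} // Mulmat(k, q): [seqLen, seqLen, numHeads]

	attn := [3]int64{v[1], scores[1], numHeads} // Mulmat(v, scores): [headDim, seqLen, numHeads]
	attn = permute(attn, [3]int{0, 2, 1})       // back to [headDim, numHeads, seqLen]

	fmt.Println(q, k, v, scores, attn)
	fmt.Println([2]int64{headDim * numHeads, seqLen}) // final Reshape: [hiddenSize, seqLen]
}

Under these assumptions, Mulmat contracts the shared headDim axis of key and query, which is also why the Key reshape mirrors Query's [headDim, numHeads, batchSize] order.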
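
Finally, the blob helper in model_test.go resolves a model tag to its on-disk GGUF by reading the manifest and renaming the digest the way ollama stores blobs, with ":" swapped for "-". A tiny sketch of that mapping; the digest is a placeholder, not a real blob.

package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

func main() {
	// Manifest digests look like "sha256:<hex>"; blob files on disk are
	// named "sha256-<hex>" under ~/.ollama/models/blobs.
	digest := "sha256:0123abc"
	fmt.Println(filepath.Join("~", ".ollama", "models", "blobs",
		strings.ReplaceAll(digest, ":", "-")))
}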