From 5994e8e8fdaca3815d5a49f27d3fab33ac9ef51f Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 4 Sep 2025 09:09:07 -0700 Subject: [PATCH] embedding gemma model (#12181) * ollama: add embeddings --- model/bytepairencoding.go | 3 +- model/model.go | 5 ++ model/models/gemma3/embed.go | 73 +++++++++++++++++++++++++ model/models/gemma3/model.go | 7 ++- model/models/gemma3/model_text.go | 9 ++-- model/sentencepiece.go | 3 +- model/vocabulary.go | 4 +- runner/ollamarunner/cache.go | 6 ++- runner/ollamarunner/cache_test.go | 2 +- runner/ollamarunner/runner.go | 90 ++++++++++++++++++++++++++----- 10 files changed, 175 insertions(+), 27 deletions(-) create mode 100644 model/models/gemma3/embed.go diff --git a/model/bytepairencoding.go b/model/bytepairencoding.go index b2cb113257..e21564aa53 100644 --- a/model/bytepairencoding.go +++ b/model/bytepairencoding.go @@ -201,12 +201,11 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) { } } - logutil.Trace("encoded", "string", s, "ids", ids) - if addSpecial && len(ids) > 0 { ids = bpe.vocab.addSpecials(ids) } + logutil.Trace("encoded", "string", s, "ids", ids) return ids, nil } diff --git a/model/model.go b/model/model.go index 30754143ae..3a72f09aa3 100644 --- a/model/model.go +++ b/model/model.go @@ -5,6 +5,7 @@ import ( "fmt" _ "image/jpeg" _ "image/png" + "math" "os" "reflect" "strconv" @@ -103,6 +104,10 @@ func New(modelPath string, params ml.BackendParams) (Model, error) { } arch := b.Config().Architecture() + if b.Config().Uint("pooling_type", math.MaxUint32) != math.MaxUint32 { + arch = arch + "_embed" + } + f, ok := models[arch] if !ok { return nil, fmt.Errorf("unsupported model architecture %q", arch) diff --git a/model/models/gemma3/embed.go b/model/models/gemma3/embed.go new file mode 100644 index 0000000000..16c299e22d --- /dev/null +++ b/model/models/gemma3/embed.go @@ -0,0 +1,73 @@ +package gemma3 + +import ( + "errors" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type embedModel struct { + model.Base + model.SentencePieceModel + + *TextModel + PoolingType uint32 + + Dense [2]*nn.Linear `gguf:"dense"` +} + +func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + batch.Outputs = batch.Positions // return all positions + hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache) + + switch m.PoolingType { + case 0: // None + case 1: // Mean + hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx) + hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + default: + return nil, errors.New("unsupported pooling type") + } + + for _, dense := range m.Dense { + hiddenStates = dense.Forward(ctx, hiddenStates) + } + + return hiddenStates, nil +} + +func newEmbedModel(c fs.Config) (model.Model, error) { + m := &embedModel{ + SentencePieceModel: model.NewSentencePieceModel( + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Scores: c.Floats("tokenizer.ggml.scores"), + Types: c.Ints("tokenizer.ggml.token_type"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{ + int32(c.Uint("tokenizer.ggml.eos_token_id")), + int32(c.Uint("tokenizer.ggml.eot_token_id", 106)), + }, + c.Ints("tokenizer.ggml.eos_token_ids")..., + 
), + }, + ), + TextModel: newTextModel(c), + PoolingType: c.Uint("pooling_type", 0), + } + + m.Cache = kvcache.NewWrapperCache( + kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift), + kvcache.NewCausalCache(m.Shift), + ) + + return m, nil +} diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index a4c90d9cff..5c92b6bf95 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -141,12 +141,11 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - - return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil + hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache) + return m.Output.Forward(ctx, hiddenStates), nil } func init() { model.Register("gemma3", New) + model.Register("gemma3_embed", newEmbedModel) } diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 70d7797e96..2a3b23939c 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -159,8 +159,11 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, return hiddenState.Add(ctx, residual) } -func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor { - hiddenState := m.TokenEmbedding.Forward(ctx, inputs) +func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor { + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + + hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize))) // set image embeddings @@ -198,5 +201,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor } hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) - return m.Output.Forward(ctx, hiddenState) + return hiddenState } diff --git a/model/sentencepiece.go b/model/sentencepiece.go index 4300f45f56..827ce00d94 100644 --- a/model/sentencepiece.go +++ b/model/sentencepiece.go @@ -181,12 +181,11 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) } } - logutil.Trace("encoded", "string", s, "ids", ids) - if addSpecial && len(ids) > 0 { ids = spm.vocab.addSpecials(ids) } + logutil.Trace("encoded", "string", s, "ids", ids) return ids, nil } diff --git a/model/vocabulary.go b/model/vocabulary.go index a86de58dfa..9b7fc789e5 100644 --- a/model/vocabulary.go +++ b/model/vocabulary.go @@ -49,7 +49,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 { slog.Warn("adding bos token to prompt which already has it", "id", v.BOS) } - slog.Debug("adding bos token to prompt", "id", v.BOS) + slog.Debug("adding bos token to prompt", "id", v.BOS[0]) ids = append([]int32{v.BOS[0]}, ids...) 
} @@ -58,7 +58,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 { slog.Warn("adding eos token to prompt which already has it", "id", v.EOS) } - slog.Debug("adding eos token to prompt", "id", v.EOS) + slog.Debug("adding eos token to prompt", "id", v.EOS[0]) ids = append(ids, v.EOS[0]) } diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index 8f30037c8e..955ef9b3d9 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -95,7 +95,7 @@ type InputCacheSlot struct { lastUsed time.Time } -func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*input.Input, error) { +func (c *InputCache) LoadCacheSlot(prompt []*input.Input, cachePrompt bool) (*InputCacheSlot, []*input.Input, error) { var slot *InputCacheSlot var numPast int32 var err error @@ -113,6 +113,10 @@ func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*i return nil, nil, err } + if !cachePrompt { + numPast = 0 + } + slot.InUse = true slot.lastUsed = time.Now() diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go index 49cb6c5474..c0693e8343 100644 --- a/runner/ollamarunner/cache_test.go +++ b/runner/ollamarunner/cache_test.go @@ -393,7 +393,7 @@ func TestLoadCacheSlot(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt) + slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt, true) // Check error state if (err != nil) != tt.wantErr { diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 9b54a37206..df3ce1d9f2 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -11,6 +11,7 @@ import ( "image" "log" "log/slog" + "math" "net" "net/http" "os" @@ -405,6 +406,8 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) { func (s *Server) run(ctx context.Context) { s.ready.Wait() + supportsAsync := s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 + var activeBatch batchState for { select { @@ -418,7 +421,12 @@ func (s *Server) run(ctx context.Context) { if err != nil { panic(err) } - go s.computeBatch(activeBatch) + + if supportsAsync { + go s.computeBatch(activeBatch) + } else { + s.computeBatch(activeBatch) + } } } } @@ -670,7 +678,8 @@ func (s *Server) computeBatch(activeBatch batchState) { activeBatch.computeStartedCh <- struct{}{} }, activeBatch.modelOutput) - logits := activeBatch.modelOutput.Floats() + + outputs := activeBatch.modelOutput.Floats() logutil.Trace("computeBatch: logits ready", "batchID", activeBatch.id) @@ -689,16 +698,15 @@ func (s *Server) computeBatch(activeBatch batchState) { // if done processing the prompt, generate an embedding and return if seq.embeddingOnly { - // TODO(jessegross): Embedding support - slog.Warn("generation of embedding outputs not yet supported", "id", activeBatch.id, "seqIdx", i) + seq.embedding <- outputs s.removeSequence(i, llm.DoneReasonStop) continue } // sample a token - vocabSize := len(logits) / len(activeBatch.batch.Outputs) - logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(logits), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches) - token, err := seq.sampler.Sample(logits[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize]) + vocabSize := len(outputs) / len(activeBatch.batch.Outputs) + logutil.Trace("computeBatch: vocab details", 
"batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches) + token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize]) if err != nil { s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err) return @@ -834,7 +842,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { found := false for i, sq := range s.seqs { if sq == nil { - seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs) + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true) if err != nil { s.mu.Unlock() s.seqsSem.Release(1) @@ -890,6 +898,67 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { } } +func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { + if s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 { + http.Error(w, "this model does not support embeddings", http.StatusNotImplemented) + return + } + + var req llm.EmbeddingRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true}) + if err != nil { + http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError) + return + } + + if err := s.seqsSem.Acquire(r.Context(), 1); err != nil { + if errors.Is(err, context.Canceled) { + slog.Info("aborting embedding request due to client closing the connection") + } else { + http.Error(w, fmt.Sprintf("failed to acquire semaphore: %v", err), http.StatusInternalServerError) + } + return + } + + s.mu.Lock() + found := false + for i, sq := range s.seqs { + if sq == nil { + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false) + if err != nil { + s.mu.Unlock() + s.seqsSem.Release(1) + http.Error(w, fmt.Sprintf("failed to load cache: %v", err), http.StatusInternalServerError) + return + } + + s.seqs[i] = seq + s.cond.Signal() + found = true + break + } + } + s.mu.Unlock() + + if !found { + s.seqsSem.Release(1) + http.Error(w, "could not find an available sequence", http.StatusInternalServerError) + return + } + + if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{ + Embedding: <-seq.embedding, + }); err != nil { + http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) + } +} + func (s *Server) health(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{ @@ -1206,10 +1275,7 @@ func Execute(args []string) error { mux := http.NewServeMux() // TODO: support embeddings mux.HandleFunc("POST /load", server.load) - mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "this model does not support embeddings", http.StatusNotImplemented) - }) - + mux.HandleFunc("POST /embedding", server.embeddings) mux.HandleFunc("POST /completion", server.completion) mux.HandleFunc("GET /health", server.health)