diff --git a/llama/runner/image.go b/llama/runner/image.go
index d50645e81..3b5621860 100644
--- a/llama/runner/image.go
+++ b/llama/runner/image.go
@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"hash/maphash"
 	"log/slog"
+	"slices"
 	"sync"
 	"time"
 
@@ -96,6 +97,19 @@ func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
 	}
 }
 
+// NeedCrossAttention reports whether at least one of the given inputs has a
+// non-nil embed. It returns false when the receiver is nil or its mllama
+// field is nil, so it is safe to call on a nil *ImageContext.
+func (c *ImageContext) NeedCrossAttention(inputs ...input) bool {
+	if c == nil || c.mllama == nil {
+		return false
+	}
+
+	return slices.ContainsFunc(inputs, func(in input) bool {
+		return in.embed != nil
+	})
+}
+
 type imageCache struct {
 	key uint64
 	val [][]float32
diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index a137f8795..a7e0e3b01 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -52,6 +52,10 @@ type Sequence struct {
 	// input cache being used by this sequence
 	cache *InputCacheSlot
 
+	// does this sequence require cross-attention layers to be processed? - if we have seen
+	// an image for certain multi-modal models
+	crossAttention bool
+
 	// channel to send responses over
 	responses chan string
 
@@ -287,7 +291,6 @@ func flushPending(seq *Sequence) bool {
 func (s *Server) removeSequence(seqIndex int, reason string) {
 	seq := s.seqs[seqIndex]
 
-	s.lc.SetCrossAttention(false)
 	flushPending(seq)
 	seq.doneReason = reason
 	close(seq.responses)
@@ -334,6 +337,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 	defer s.mu.Unlock()
 
 	var batch *llama.Batch
+	crossAttention := false
 
 	seqIdx := s.nextSeq - 1
 	for range s.seqs {
@@ -367,8 +371,9 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 					batch = tokenBatch
 				} else {
 					batch = embedBatch
+					seq.crossAttention = s.image.NeedCrossAttention(input)
 				}
-			} else if embedding != batch.IsEmbedding() {
+			} else if embedding != batch.IsEmbedding() || crossAttention != seq.crossAttention {
 				s.nextSeq = seqIdx
 				break
 			}
@@ -378,6 +383,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 				break
 			}
 
+			crossAttention = seq.crossAttention
 			batch.Add(input.token, input.embed, seq.numPast, []int{seq.cache.Id}, numInputsProcessed+1 == len(seq.inputs))
 			seq.numPast++
 			numInputsProcessed++
@@ -394,6 +400,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		return
 	}
 
+	s.lc.SetCrossAttention(crossAttention)
+
 	err := s.lc.Decode(batch)
 	if err != nil {
 		slog.Error("failed to decode batch", "error", err)
@@ -605,13 +613,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	s.mu.Lock()
 	for i, sq := range s.seqs {
 		if sq == nil {
-			for _, input := range seq.inputs {
-				if input.embed != nil {
-					s.lc.SetCrossAttention(true)
-					break
-				}
-			}
-
 			seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
 			if err != nil {
 				s.mu.Unlock()
@@ -619,6 +620,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 				return
 			}
 
+			seq.crossAttention = s.image.NeedCrossAttention(seq.cache.Inputs...)
+
 			s.seqs[i] = seq
 			s.cond.Signal()
 			break