diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index adcb3f738..cf5e6b911 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -89,7 +89,7 @@ type InputCacheSlot struct { lastUsed time.Time } -func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) { +func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) { var slot *InputCacheSlot var numPast int32 var err error @@ -107,11 +107,6 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp return nil, nil, err } - // TODO (brucemacd): cachePrompt is always true for completion, but false for embedding, can this be improved? - if !cachePrompt { - numPast = 0 - } - slot.InUse = true slot.lastUsed = time.Now() diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go index 0a1b73f5a..f8925d119 100644 --- a/runner/ollamarunner/cache_test.go +++ b/runner/ollamarunner/cache_test.go @@ -297,3 +297,131 @@ func TestShiftDiscard(t *testing.T) { }) } } + +func TestLoadCacheSlot(t *testing.T) { + tests := []struct { + name string + cache InputCache + prompt []input.Input + wantErr bool + expectedSlotId int + expectedPrompt int // expected length of remaining prompt + }{ + { + name: "Basic cache hit - single user", + cache: InputCache{ + multiUserCache: false, + slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, + InUse: false, + lastUsed: time.Now().Add(-time.Second), + }, + { + Id: 1, + Inputs: []input.Input{}, + InUse: false, + lastUsed: time.Now().Add(-2 * time.Second), + }, + }, + }, + prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + wantErr: false, + expectedSlotId: 0, + expectedPrompt: 1, // Only token 3 remains + }, + { + name: "Basic cache hit - multi user", + cache: InputCache{ + multiUserCache: true, + slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, + InUse: false, + lastUsed: time.Now().Add(-time.Second), + }, + { + Id: 1, + Inputs: []input.Input{}, + InUse: false, + lastUsed: time.Now().Add(-2 * time.Second), + }, + }, + }, + prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + wantErr: false, + expectedSlotId: 0, + expectedPrompt: 1, // Only token 3 remains + }, + { + name: "Exact match - leave one input", + cache: InputCache{ + multiUserCache: false, + slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, + InUse: false, + lastUsed: time.Now().Add(-time.Second), + }, + }, + }, + prompt: []input.Input{{Token: 1}, {Token: 2}}, + wantErr: false, + expectedSlotId: 0, + expectedPrompt: 1, // Should leave 1 token for sampling + }, + { + name: "No available slots", + cache: InputCache{ + multiUserCache: false, + slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, + InUse: true, + lastUsed: time.Now().Add(-time.Second), + }, + }, + }, + prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + wantErr: true, + expectedSlotId: -1, + expectedPrompt: -1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt) + + // Check error state + if (err != nil) != tt.wantErr { + t.Errorf("LoadCacheSlot() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr { + return // Skip further checks if we expected an error + } + + // Verify slot ID + if slot.Id != tt.expectedSlotId { + t.Errorf("LoadCacheSlot() slot ID = %v, expected %v", slot.Id, tt.expectedSlotId) + } + + // Verify slot is now marked in use + if !slot.InUse { + t.Errorf("LoadCacheSlot() slot not marked InUse") + } + + // Verify remaining prompt length + if len(remainingPrompt) != tt.expectedPrompt { + t.Errorf("LoadCacheSlot() remaining prompt length = %v, expected %v", + len(remainingPrompt), tt.expectedPrompt) + } + }) + } +} diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index d4c24556c..98fcf2c04 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -590,7 +590,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { found := false for i, sq := range s.seqs { if sq == nil { - seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true) + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs) if err != nil { s.mu.Unlock() http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)