This commit is contained in:
Bruce MacDonald 2025-03-10 17:10:48 -07:00
parent aebca4b70c
commit cb10c99297
2 changed files with 15 additions and 15 deletions

View File

@ -242,7 +242,7 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
}
type ErrReprocessInputs struct {
Inputs []input
Inputs []input.Input
}
func (e *ErrReprocessInputs) Error() string {
@ -279,12 +279,12 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
_ = c.cache.Remove(slot.Id, 0, -1)
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
newInputs := make([]input, numKeep+inputLen-(numKeep+discard))
newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
// Reset the slot inputs since we've cleared the cache
slot.Inputs = []input{}
slot.Inputs = []input.Input{}
// Return error with inputs that need to be reprocessed
return &ErrReprocessInputs{Inputs: newInputs}

View File

@ -315,20 +315,20 @@ func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error {
}
// Stub implementations for other interface methods
func (m *mockCache) SetLayer(layer int) {}
func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { return nil, nil, nil }
func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor) {}
func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {}
func (m *mockCache) Close() {}
func (m *mockCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error { return nil }
func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {}
func (m *mockCache) SetConfig(ml.CacheConfig) {}
func (m *mockCache) SetLayer(layer int) {}
func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { return nil, nil, nil }
func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor) {}
func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {}
func (m *mockCache) Close() {}
func (m *mockCache) StartForward(ctx ml.Context, opts input.Options) error { return nil }
func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {}
func (m *mockCache) SetConfig(ml.CacheConfig) {}
func TestShiftCacheSlot(t *testing.T) {
tests := []struct {
name string
numCtx int32
inputs []input
inputs []input.Input
numKeep int32
cacheErr bool
wantErr any
@ -337,7 +337,7 @@ func TestShiftCacheSlot(t *testing.T) {
{
name: "Normal shift",
numCtx: 10,
inputs: []input{{token: 1}, {token: 2}, {token: 3}, {token: 4}, {token: 5}, {token: 6}, {token: 7}, {token: 8}, {token: 9}, {token: 10}},
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
numKeep: 2,
cacheErr: false, // No error
wantErr: nil,
@ -346,7 +346,7 @@ func TestShiftCacheSlot(t *testing.T) {
{
name: "Cache removal fails",
numCtx: 10,
inputs: []input{{token: 1}, {token: 2}, {token: 3}, {token: 4}, {token: 5}, {token: 6}, {token: 7}, {token: 8}, {token: 9}, {token: 10}},
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
numKeep: 2,
cacheErr: true,
wantErr: &ErrReprocessInputs{},
@ -363,7 +363,7 @@ func TestShiftCacheSlot(t *testing.T) {
}
slot := &InputCacheSlot{
Id: 123,
Inputs: make([]input, len(tt.inputs)),
Inputs: make([]input.Input, len(tt.inputs)),
}
copy(slot.Inputs, tt.inputs)