diff --git a/integration/context_test.go b/integration/context_test.go index 24c57dcf26..ca6f16087c 100644 --- a/integration/context_test.go +++ b/integration/context_test.go @@ -36,7 +36,7 @@ func TestLongInputContext(t *testing.T) { if err := PullIfMissing(ctx, client, req.Model); err != nil { t.Fatalf("PullIfMissing failed: %v", err) } - DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second) + DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second) } func TestContextExhaustion(t *testing.T) { diff --git a/llm/server.go b/llm/server.go index e0a652ec0b..7bc2ca13df 100644 --- a/llm/server.go +++ b/llm/server.go @@ -173,6 +173,8 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a opts.NumCtx = int(trainCtx) } + opts.NumBatch = min(opts.NumBatch, opts.NumCtx) + loadRequest := LoadRequest{LoraPath: adapters, KvSize: opts.NumCtx * numParallel, BatchSize: opts.NumBatch, Parallel: numParallel, MultiUserCache: envconfig.MultiUserCache()} defaultThreads := discover.GetSystemInfo().GetOptimalThreadCount() diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index 1d1ca51806..f558f7b87a 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -34,8 +34,8 @@ type InputCache struct { func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) { numCtx := kvSize / int32(numSlots) - if numCtx < 1 { - return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots) + if int(numCtx) < batchSize { + return nil, fmt.Errorf("kv size must be at least as large as batch size * parallel (kv: %v batch: %v parallel: %v)", kvSize, batchSize, numSlots) } slots := make([]InputCacheSlot, numSlots)