From bb6fd02298bda99e3d77318d4f282bb2c30b3603 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Fri, 10 May 2024 10:17:12 -0700 Subject: [PATCH] Don't clamp ctx size in `PredictServerFit` (#4317) * don't clamp ctx size in `PredictServerFit` * minimum 4 context * remove context warning --- llm/memory.go | 11 +---------- llm/server.go | 10 +--------- server/sched.go | 4 ++++ 3 files changed, 6 insertions(+), 19 deletions(-) diff --git a/llm/memory.go b/llm/memory.go index 6890b08c5..df7081cf0 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -12,17 +12,8 @@ import ( // This algorithm looks for a complete fit to determine if we need to unload other models func PredictServerFit(allGpus gpu.GpuInfoList, ggml *GGML, adapters, projectors []string, opts api.Options) (bool, uint64) { - var estimatedVRAM uint64 - if opts.NumCtx > int(ggml.KV().ContextLength()) { - slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength()) - opts.NumCtx = int(ggml.KV().ContextLength()) - } - - if opts.NumCtx < 4 { - opts.NumCtx = 4 - } - // Split up the GPUs by type and try them + var estimatedVRAM uint64 for _, gpus := range allGpus.ByLibrary() { var layerCount int layerCount, estimatedVRAM, _ = EstimateGPULayers(gpus, ggml, projectors, opts) diff --git a/llm/server.go b/llm/server.go index 8d0744a98..81a2dec4b 100644 --- a/llm/server.go +++ b/llm/server.go @@ -77,15 +77,7 @@ func LoadModel(model string) (*GGML, error) { // The gpu list must be a single family. 
func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) { var err error - if opts.NumCtx > int(ggml.KV().ContextLength()) { - slog.Warn("requested context length is greater than the model's training context window size", "requested", opts.NumCtx, "training size", ggml.KV().ContextLength()) - } - - if opts.NumCtx < 4 { - opts.NumCtx = 4 - } - - cpuRunner := "" + var cpuRunner string var estimatedVRAM uint64 var estimatedTotal uint64 var systemMemory uint64 diff --git a/server/sched.go b/server/sched.go index 96235ea5c..bbf333d7d 100644 --- a/server/sched.go +++ b/server/sched.go @@ -61,6 +61,10 @@ func InitScheduler(ctx context.Context) *Scheduler { // context must be canceled to decrement ref count and release the runner func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) { // allocate a large enough kv cache for all parallel requests + if opts.NumCtx < 4 { + opts.NumCtx = 4 + } + opts.NumCtx = opts.NumCtx * envconfig.NumParallel req := &LlmRequest{