From 314573bfe8afd6e93389ec519699da20285a38dc Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Mon, 24 Feb 2025 13:26:35 -0800 Subject: [PATCH] config: allow setting context length through env var (#8938) * envconfig: allow setting context length through env var --- api/types.go | 4 +++- envconfig/config.go | 3 +++ envconfig/config_test.go | 16 ++++++++++++++++ llm/memory_test.go | 1 + 4 files changed, 23 insertions(+), 1 deletion(-) diff --git a/api/types.go b/api/types.go index f4c5b1058..637ca2042 100644 --- a/api/types.go +++ b/api/types.go @@ -10,6 +10,8 @@ import ( "strconv" "strings" "time" + + "github.com/ollama/ollama/envconfig" ) // StatusError is an error with an HTTP status code and message. @@ -609,7 +611,7 @@ func DefaultOptions() Options { Runner: Runner{ // options set when the model is loaded - NumCtx: 2048, + NumCtx: int(envconfig.ContextLength()), NumBatch: 512, NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically NumThread: 0, // let the runtime decide diff --git a/envconfig/config.go b/envconfig/config.go index d867bdac2..6117aa264 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -167,6 +167,8 @@ var ( MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE") // Enable the new Ollama engine NewEngine = Bool("OLLAMA_NEW_ENGINE") + // ContextLength sets the default context length + ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 2048) ) func String(s string) func() string { @@ -252,6 +254,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"}, "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"}, + "OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 2048)"}, "OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"}, // Informational diff --git a/envconfig/config_test.go b/envconfig/config_test.go index 993ddd9ca..385dab5f1 100644 --- a/envconfig/config_test.go +++ b/envconfig/config_test.go @@ -272,3 +272,19 @@ func TestVar(t *testing.T) { }) } } + +func TestContextLength(t *testing.T) { + cases := map[string]uint{ + "": 2048, + "4096": 4096, + } + + for k, v := range cases { + t.Run(k, func(t *testing.T) { + t.Setenv("OLLAMA_CONTEXT_LENGTH", k) + if i := ContextLength(); i != v { + t.Errorf("%s: expected %d, got %d", k, v, i) + } + }) + } +} diff --git a/llm/memory_test.go b/llm/memory_test.go index e49d25410..40cc01dff 100644 --- a/llm/memory_test.go +++ b/llm/memory_test.go @@ -17,6 +17,7 @@ import ( func TestEstimateGPULayers(t *testing.T) { t.Setenv("OLLAMA_DEBUG", "1") t.Setenv("OLLAMA_KV_CACHE_TYPE", "") // Ensure default f16 + t.Setenv("OLLAMA_CONTEXT_LENGTH", "2048") modelName := "dummy" f, err := os.CreateTemp(t.TempDir(), modelName)