//go:build integration && perf

package integration

import (
	"context"
	"fmt"
	"log/slog"
	"math"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/format"
)

var (
	// Models that don't work reliably with the large context prompt in this test case
	longContextFlakes = []string{
		"granite-code:latest",
		"nemotron-mini:latest",
		"falcon:latest",  // 2k model
		"falcon2:latest", // 2k model
		"minicpm-v:latest",
		"qwen:latest",
		"solar-pro:latest",
	}
)

// Note: this test case can take a long time to run, particularly on models with
// large contexts.  Run with -timeout set to a large value to get reasonable coverage.
// Example usage:
//
//	go test --tags=integration,perf -count 1 ./integration -v -timeout 90m -run TestModelsPerf 2>&1 | tee int.log
//	cat int.log | grep MODEL_PERF_HEADER | head -1 | cut -f2- -d: > perf.csv
//	cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv
func TestModelsPerf(t *testing.T) {
	softTimeout, hardTimeout := getTimeouts(t)
	slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
	ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	// TODO use info API eventually
	var maxVram uint64
	var err error
	if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
		maxVram, err = strconv.ParseUint(s, 10, 64)
		if err != nil {
			t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
		}
	} else {
		slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
	}

	data, err := os.ReadFile(filepath.Join("testdata", "shakespeare.txt"))
	if err != nil {
		t.Fatalf("failed to open test data file: %s", err)
	}
	longPrompt := "summarize the following: " + string(data)

	var chatModels []string
	if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
		chatModels = ollamaEngineChatModels
	} else {
		chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
	}

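	// Exercise each candidate model across several context sizes, recording load time
	// and prompt/eval throughput as MODEL_PERF_* CSV lines on stderr.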
	for _, model := range chatModels {
		t.Run(model, func(t *testing.T) {
			if time.Since(started) > softTimeout {
				t.Skip("skipping remaining tests to avoid excessive runtime")
			}
			if err := PullIfMissing(ctx, client, model); err != nil {
				t.Fatalf("pull failed %s", err)
			}
			var maxContext int
			resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
			if err != nil {
				t.Fatalf("show failed: %s", err)
			}
			arch := resp.ModelInfo["general.architecture"].(string)
			maxContext = int(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))

			if maxVram > 0 {
				resp, err := client.List(ctx)
				if err != nil {
					t.Fatalf("list models failed %v", err)
				}
				for _, m := range resp.Models {
					// For these tests we want to exercise some amount of overflow onto the CPU
					if m.Name == model && float32(m.Size)*0.75 > float32(maxVram) {
						t.Skipf("model %s is too large %s for available VRAM %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
					}
				}
			}
			slog.Info("scenario", "model", model, "max_context", maxContext)
			loaded := false
			defer func() {
				// best effort unload once we're done with the model
				if loaded {
					client.Generate(ctx, &api.GenerateRequest{Model: model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
				}
			}()

			// Some models don't handle the long context data well, so skip them to avoid flaky test results
			longContextFlake := false
			for _, flake := range longContextFlakes {
				if model == flake {
					longContextFlake = true
					break
				}
			}

			// iterate through a few context sizes for coverage without excessive runtime
			var contexts []int
			keepGoing := true
			if maxContext > 16384 {
				contexts = []int{4096, 8192, 16384, maxContext}
			} else if maxContext > 8192 {
				contexts = []int{4096, 8192, maxContext}
			} else if maxContext > 4096 {
				contexts = []int{4096, maxContext}
			} else if maxContext > 0 {
				contexts = []int{maxContext}
			} else {
				t.Fatal("unknown max context size")
			}
			for _, numCtx := range contexts {
				if !keepGoing && numCtx > 8192 {
					// Always try up to 8k before bailing out
					break
				}
				skipLongPrompt := false

				// Workaround bug 11172 temporarily...
				maxPrompt := longPrompt
				// If we fill the context too full with the prompt, many models
				// quickly hit context shifting and go bad.
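				// Truncation heuristic: assuming roughly four characters per token,
				// numCtx*2 characters keeps the prompt to about half the context window.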
				if len(maxPrompt) > numCtx*2 { // typically yields ~1/2 full context
					maxPrompt = maxPrompt[:numCtx*2]
				}

				testCases := []struct {
					prompt  string
					anyResp []string
				}{
					{"why is the sky blue?", []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}},
					{maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy"}},
				}
				var gpuPercent int
				for _, tc := range testCases {
					if len(tc.prompt) > 100 && (longContextFlake || skipLongPrompt) {
						slog.Info("skipping long prompt", "model", model, "num_ctx", numCtx, "gpu_percent", gpuPercent)
						continue
					}
					req := api.GenerateRequest{
						Model:     model,
						Prompt:    tc.prompt,
						KeepAlive: &api.Duration{Duration: 20 * time.Second}, // long enough to ensure a ps returns
						Options: map[string]interface{}{
							"temperature": 0,
							"seed":        123,
							"num_ctx":     numCtx,
						},
					}
					atLeastOne := false
					var resp api.GenerateResponse
					stream := false
					req.Stream = &stream

					// Avoid potentially getting stuck indefinitely
					limit := 5 * time.Minute
					genCtx, cancel := context.WithDeadlineCause(
						ctx,
						time.Now().Add(limit),
						fmt.Errorf("generate on model %s with ctx %d took longer than %v", model, numCtx, limit),
					)
					defer cancel()

					err = client.Generate(genCtx, &req, func(rsp api.GenerateResponse) error {
						resp = rsp
						return nil
					})
					if err != nil {
						// Avoid excessive test runs, but don't consider it a failure with massive context
						if numCtx > 16384 && strings.Contains(err.Error(), "took longer") {
							slog.Warn("max context was taking too long, skipping", "error", err)
							keepGoing = false
							skipLongPrompt = true
							continue
						}
						t.Fatalf("generate error: ctx:%d err:%s", numCtx, err)
					}
					loaded = true

					for _, expResp := range tc.anyResp {
						if strings.Contains(strings.ToLower(resp.Response), expResp) {
							atLeastOne = true
							break
						}
					}
					if !atLeastOne {
						t.Fatalf("response didn't contain expected values: ctx:%d expected:%v response:%s", numCtx, tc.anyResp, resp.Response)
					}

					models, err := client.ListRunning(ctx)
					if err != nil {
						slog.Warn("failed to list running models", "error", err)
						continue
					}
					if len(models.Models) > 1 {
						slog.Warn("multiple models loaded, may impact performance results", "loaded", models.Models)
					}
					for _, m := range models.Models {
						if m.Name == model {
							if m.SizeVRAM == 0 {
								slog.Info("Model fully loaded into CPU")
								gpuPercent = 0
								keepGoing = false
								skipLongPrompt = true
							} else if m.SizeVRAM == m.Size {
								slog.Info("Model fully loaded into GPU")
								gpuPercent = 100
							} else {
								sizeCPU := m.Size - m.SizeVRAM
								cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
								gpuPercent = int(100 - cpuPercent)
								slog.Info("Model split between CPU/GPU", "CPU", cpuPercent, "GPU", gpuPercent)
								keepGoing = false

								// Heuristic to avoid excessive test run time
								if gpuPercent < 90 {
									skipLongPrompt = true
								}
							}
						}
					}

					fmt.Fprintf(os.Stderr, "MODEL_PERF_HEADER:%s,%s,%s,%s,%s,%s,%s\n",
						"MODEL",
						"CONTEXT",
						"GPU PERCENT",
						"PROMPT COUNT",
						"LOAD TIME",
						"PROMPT EVAL TPS",
						"EVAL TPS",
					)
					fmt.Fprintf(os.Stderr, "MODEL_PERF_DATA:%s,%d,%d,%d,%0.2f,%0.2f,%0.2f\n",
						model,
						numCtx,
						gpuPercent,
						resp.PromptEvalCount,
						float64(resp.LoadDuration)/1000000000.0,
						float64(resp.PromptEvalCount)/(float64(resp.PromptEvalDuration)/1000000000.0),
						float64(resp.EvalCount)/(float64(resp.EvalDuration)/1000000000.0),
					)
				}
			}
		})
	}
}