mirror of
https://github.com/ollama/ollama.git
synced 2025-08-23 19:12:47 +02:00
usage example:
go test --tags=integration,perf -count 1 ./integration -v -timeout 1h -run TestModelsPerf 2>&1 | tee int.log
cat int.log | grep MODEL_PERF_HEADER | head -1 | cut -f2- -d: > perf.csv
cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv
267 lines
8.2 KiB
Go
267 lines
8.2 KiB
Go
//go:build integration && perf
|
|
|
|
package integration
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"log/slog"
|
|
"math"
|
|
"os"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/ollama/ollama/api"
|
|
"github.com/ollama/ollama/format"
|
|
)
|
|
|
|
// longContextFlakes names models that don't work reliably with the large
// context prompt used by TestModelsPerf; their long-prompt cases are skipped
// so they don't produce flaky results.
var longContextFlakes = []string{
	"granite-code:latest",
	"nemotron-mini:latest",
	"falcon:latest",  // 2k model
	"falcon2:latest", // 2k model
	"minicpm-v:latest",
	"qwen:latest",
	"solar-pro:latest",
}
|
|
|
|
// Note: this test case can take a long time to run, particularly on models with
|
|
// large contexts. Run with -timeout set to a large value to get reasonable coverage
|
|
// Example usage:
|
|
//
|
|
// go test --tags=integration,perf -count 1 ./integration -v -timeout 90m -run TestModelsPerf 2>&1 | tee int.log
|
|
// cat int.log | grep MODEL_PERF_HEADER | head -1| cut -f2- -d: > perf.csv
|
|
// cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv
|
|
func TestModelsPerf(t *testing.T) {
|
|
softTimeout, hardTimeout := getTimeouts(t)
|
|
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
|
|
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
|
defer cancel()
|
|
client, _, cleanup := InitServerConnection(ctx, t)
|
|
defer cleanup()
|
|
|
|
// TODO use info API eventually
|
|
var maxVram uint64
|
|
var err error
|
|
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
|
|
maxVram, err = strconv.ParseUint(s, 10, 64)
|
|
if err != nil {
|
|
t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
|
|
}
|
|
} else {
|
|
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
|
|
}
|
|
|
|
data, err := ioutil.ReadFile(filepath.Join("testdata", "shakespeare.txt"))
|
|
if err != nil {
|
|
t.Fatalf("failed to open test data file: %s", err)
|
|
}
|
|
longPrompt := "summarize the following: " + string(data)
|
|
|
|
var chatModels []string
|
|
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
|
|
chatModels = ollamaEngineChatModels
|
|
} else {
|
|
chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
|
|
}
|
|
|
|
for _, model := range chatModels {
|
|
t.Run(model, func(t *testing.T) {
|
|
if time.Now().Sub(started) > softTimeout {
|
|
t.Skip("skipping remaining tests to avoid excessive runtime")
|
|
}
|
|
if err := PullIfMissing(ctx, client, model); err != nil {
|
|
t.Fatalf("pull failed %s", err)
|
|
}
|
|
var maxContext int
|
|
|
|
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
|
|
if err != nil {
|
|
t.Fatalf("show failed: %s", err)
|
|
}
|
|
arch := resp.ModelInfo["general.architecture"].(string)
|
|
maxContext = int(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))
|
|
|
|
if maxVram > 0 {
|
|
resp, err := client.List(ctx)
|
|
if err != nil {
|
|
t.Fatalf("list models failed %v", err)
|
|
}
|
|
for _, m := range resp.Models {
|
|
// For these tests we want to exercise a some amount of overflow on the CPU
|
|
if m.Name == model && float32(m.Size)*0.75 > float32(maxVram) {
|
|
t.Skipf("model %s is too large %s for available VRAM %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
|
|
}
|
|
}
|
|
}
|
|
slog.Info("scneario", "model", model, "max_context", maxContext)
|
|
loaded := false
|
|
defer func() {
|
|
// best effort unload once we're done with the model
|
|
if loaded {
|
|
client.Generate(ctx, &api.GenerateRequest{Model: model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
|
|
}
|
|
}()
|
|
|
|
// Some models don't handle the long context data well so skip them to avoid flaky test results
|
|
longContextFlake := false
|
|
for _, flake := range longContextFlakes {
|
|
if model == flake {
|
|
longContextFlake = true
|
|
break
|
|
}
|
|
}
|
|
|
|
// iterate through a few context sizes for coverage without excessive runtime
|
|
var contexts []int
|
|
keepGoing := true
|
|
if maxContext > 16384 {
|
|
contexts = []int{4096, 8192, 16384, maxContext}
|
|
} else if maxContext > 8192 {
|
|
contexts = []int{4096, 8192, maxContext}
|
|
} else if maxContext > 4096 {
|
|
contexts = []int{4096, maxContext}
|
|
} else if maxContext > 0 {
|
|
contexts = []int{maxContext}
|
|
} else {
|
|
t.Fatal("unknown max context size")
|
|
}
|
|
for _, numCtx := range contexts {
|
|
if !keepGoing && numCtx > 8192 { // Always try up to 8k before bailing out
|
|
break
|
|
}
|
|
skipLongPrompt := false
|
|
|
|
// Workaround bug 11172 temporarily...
|
|
maxPrompt := longPrompt
|
|
// If we fill the context too full with the prompt, many models
|
|
// quickly hit context shifting and go bad.
|
|
if len(maxPrompt) > numCtx*2 { // typically yields ~1/2 full context
|
|
maxPrompt = maxPrompt[:numCtx*2]
|
|
}
|
|
|
|
testCases := []struct {
|
|
prompt string
|
|
anyResp []string
|
|
}{
|
|
{"why is the sky blue?", []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}},
|
|
{maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy"}},
|
|
}
|
|
var gpuPercent int
|
|
for _, tc := range testCases {
|
|
if len(tc.prompt) > 100 && (longContextFlake || skipLongPrompt) {
|
|
slog.Info("skipping long prompt", "model", model, "num_ctx", numCtx, "gpu_percent", gpuPercent)
|
|
continue
|
|
}
|
|
req := api.GenerateRequest{
|
|
Model: model,
|
|
Prompt: tc.prompt,
|
|
KeepAlive: &api.Duration{Duration: 20 * time.Second}, // long enough to ensure a ps returns
|
|
Options: map[string]interface{}{
|
|
"temperature": 0,
|
|
"seed": 123,
|
|
"num_ctx": numCtx,
|
|
},
|
|
}
|
|
atLeastOne := false
|
|
var resp api.GenerateResponse
|
|
|
|
stream := false
|
|
req.Stream = &stream
|
|
|
|
// Avoid potentially getting stuck indefinitely
|
|
limit := 5 * time.Minute
|
|
genCtx, cancel := context.WithDeadlineCause(
|
|
ctx,
|
|
time.Now().Add(limit),
|
|
fmt.Errorf("generate on model %s with ctx %d took longer than %v", model, numCtx, limit),
|
|
)
|
|
defer cancel()
|
|
|
|
err = client.Generate(genCtx, &req, func(rsp api.GenerateResponse) error {
|
|
resp = rsp
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
// Avoid excessive test runs, but don't consider a failure with massive context
|
|
if numCtx > 16384 && strings.Contains(err.Error(), "took longer") {
|
|
slog.Warn("max context was taking too long, skipping", "error", err)
|
|
keepGoing = false
|
|
skipLongPrompt = true
|
|
continue
|
|
}
|
|
t.Fatalf("generate error: ctx:%d err:%s", numCtx, err)
|
|
}
|
|
loaded = true
|
|
for _, expResp := range tc.anyResp {
|
|
if strings.Contains(strings.ToLower(resp.Response), expResp) {
|
|
atLeastOne = true
|
|
break
|
|
}
|
|
}
|
|
if !atLeastOne {
|
|
t.Fatalf("response didn't contain expected values: ctx:%d expected:%v response:%s ", numCtx, tc.anyResp, resp.Response)
|
|
}
|
|
models, err := client.ListRunning(ctx)
|
|
if err != nil {
|
|
slog.Warn("failed to list running models", "error", err)
|
|
continue
|
|
}
|
|
if len(models.Models) > 1 {
|
|
slog.Warn("multiple models loaded, may impact performance results", "loaded", models.Models)
|
|
}
|
|
for _, m := range models.Models {
|
|
if m.Name == model {
|
|
if m.SizeVRAM == 0 {
|
|
slog.Info("Model fully loaded into CPU")
|
|
gpuPercent = 0
|
|
keepGoing = false
|
|
skipLongPrompt = true
|
|
} else if m.SizeVRAM == m.Size {
|
|
slog.Info("Model fully loaded into GPU")
|
|
gpuPercent = 100
|
|
} else {
|
|
sizeCPU := m.Size - m.SizeVRAM
|
|
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
|
|
gpuPercent = int(100 - cpuPercent)
|
|
slog.Info("Model split between CPU/GPU", "CPU", cpuPercent, "GPU", gpuPercent)
|
|
keepGoing = false
|
|
|
|
// Heuristic to avoid excessive test run time
|
|
if gpuPercent < 90 {
|
|
skipLongPrompt = true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
fmt.Fprintf(os.Stderr, "MODEL_PERF_HEADER:%s,%s,%s,%s,%s,%s,%s\n",
|
|
"MODEL",
|
|
"CONTEXT",
|
|
"GPU PERCENT",
|
|
"PROMPT COUNT",
|
|
"LOAD TIME",
|
|
"PROMPT EVAL TPS",
|
|
"EVAL TPS",
|
|
)
|
|
fmt.Fprintf(os.Stderr, "MODEL_PERF_DATA:%s,%d,%d,%d,%0.2f,%0.2f,%0.2f\n",
|
|
model,
|
|
numCtx,
|
|
gpuPercent,
|
|
resp.PromptEvalCount,
|
|
float64(resp.LoadDuration)/1000000000.0,
|
|
float64(resp.PromptEvalCount)/(float64(resp.PromptEvalDuration)/1000000000.0),
|
|
float64(resp.EvalCount)/(float64(resp.EvalDuration)/1000000000.0),
|
|
)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|