diff --git a/benchmark/server_benchmark_test.go b/benchmark/server_benchmark_test.go
new file mode 100644
index 000000000..b27aa630b
--- /dev/null
+++ b/benchmark/server_benchmark_test.go
@@ -0,0 +1,178 @@
+package benchmark
+
+import (
+	"context"
+	"flag"
+	"fmt"
+	"testing"
+	"time"
+
+	"github.com/ollama/ollama/api"
+)
+
+// Command line flags
+var modelFlag string
+
+func init() {
+	flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
+	flag.Lookup("m").DefValue = "model"
+}
+
+// modelName returns the model name from flags, failing the test if not set
+func modelName(b *testing.B) string {
+	if modelFlag == "" {
+		b.Fatal("Error: -m flag is required for benchmark tests")
+	}
+	return modelFlag
+}
+
+type TestCase struct {
+	name      string
+	prompt    string
+	maxTokens int
+}
+
+// runGenerateBenchmark contains the common generate and metrics logic
+func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
+	start := time.Now()
+	var ttft time.Duration
+	var metrics api.Metrics
+
+	err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
+		if ttft == 0 && resp.Response != "" {
+			ttft = time.Since(start)
+		}
+		if resp.Done {
+			metrics = resp.Metrics
+		}
+		return nil
+	})
+
+	// Report custom metrics as part of the benchmark results
+	b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
+	b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
+
+	// Token throughput metrics
+	promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
+	genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
+	b.ReportMetric(promptThroughput, "prompt_tok/s")
+	b.ReportMetric(genThroughput, "gen_tok/s")
+
+	// Token counts
+	b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
+	b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
+	if err != nil {
+		b.Fatal(err)
+	}
+}
+
+// BenchmarkColdStart runs benchmarks with model loading from cold state
+func BenchmarkColdStart(b *testing.B) {
+	client := setup(b)
+	tests := []TestCase{
+		{"short_prompt", "Write a long story", 100},
+		{"medium_prompt", "Write a detailed economic analysis", 500},
+		{"long_prompt", "Write a comprehensive AI research paper", 1000},
+	}
+	m := modelName(b)
+
+	for _, tt := range tests {
+		b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
+			ctx := context.Background()
+
+			// Set number of tokens as our throughput metric
+			b.SetBytes(int64(tt.maxTokens))
+
+			for b.Loop() {
+				b.StopTimer()
+				// Ensure model is unloaded before each iteration
+				unload(client, m, b)
+				b.StartTimer()
+
+				req := &api.GenerateRequest{
+					Model:   m,
+					Prompt:  tt.prompt,
+					Options: map[string]interface{}{"num_predict": tt.maxTokens, "temperature": 0.1},
+				}
+
+				runGenerateBenchmark(b, ctx, client, req)
+			}
+		})
+	}
+}
+
+// BenchmarkWarmStart runs benchmarks with pre-loaded model
+func BenchmarkWarmStart(b *testing.B) {
+	client := setup(b)
+	tests := []TestCase{
+		{"short_prompt", "Write a long story", 100},
+		{"medium_prompt", "Write a detailed economic analysis", 500},
+		{"long_prompt", "Write a comprehensive AI research paper", 1000},
+	}
+	m := modelName(b)
+
+	for _, tt := range tests {
+		b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
+			ctx := context.Background()
+
+			// Pre-warm the model
+			warmup(client, m, tt.prompt, b)
+
+			// Set number of tokens as our throughput metric
+			b.SetBytes(int64(tt.maxTokens))
+
+			for b.Loop() {
+				req := &api.GenerateRequest{
+					Model:   m,
+					Prompt:  tt.prompt,
+					Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
+				}
+
+				runGenerateBenchmark(b, ctx, client, req)
+			}
+		})
+	}
+}
+
+// setup verifies server and model availability
+func setup(b *testing.B) *api.Client {
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		b.Fatal(err)
+	}
+	if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
+		b.Fatalf("Model unavailable: %v", err)
+	}
+
+	return client
+}
+
+// warmup ensures the model is loaded and warmed up
+func warmup(client *api.Client, model string, prompt string, b *testing.B) {
+	for range 3 {
+		err := client.Generate(
+			context.Background(),
+			&api.GenerateRequest{
+				Model:   model,
+				Prompt:  prompt,
+				Options: map[string]interface{}{"num_predict": 50, "temperature": 0.1},
+			},
+			func(api.GenerateResponse) error { return nil },
+		)
+		if err != nil {
+			b.Logf("Error during model warm-up: %v", err)
+		}
+	}
+}
+
+// unload forces model unloading using KeepAlive: 0 parameter
+func unload(client *api.Client, model string, b *testing.B) {
+	req := &api.GenerateRequest{
+		Model:     model,
+		KeepAlive: &api.Duration{Duration: 0},
+	}
+	if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
+		b.Logf("Unload error: %v", err)
+	}
+	time.Sleep(1 * time.Second)
+}
diff --git a/docs/benchmark.md b/docs/benchmark.md
new file mode 100644
index 000000000..a7bed8083
--- /dev/null
+++ b/docs/benchmark.md
@@ -0,0 +1,59 @@
+# Benchmark
+
+This directory contains Go benchmark tests that measure the end-to-end performance of a running Ollama server. Run them to evaluate model inference performance on your hardware and to measure the impact of code changes.
+
+## When to use
+
+Run these benchmarks when:
+- Making changes to the model inference engine
+- Modifying model loading/unloading logic
+- Changing prompt processing or token generation code
+- Implementing a new model architecture
+- Testing performance across different hardware setups
+
+## Prerequisites
+
+- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
+- The model to benchmark has already been pulled (e.g. `ollama pull llama3.3`)
+
+## Usage and Examples
+
+> [!NOTE]
+> All commands must be run from the root directory of the Ollama project.
+
+Basic syntax:
+```bash
+go test -bench=. ./benchmark/... -m $MODEL_NAME
+```
+
+Required flags:
+- `-bench=.`: Run all benchmarks
+- `-m`: Model name to benchmark
+
+Optional flags:
+- `-count N`: Number of times to run each benchmark (useful for statistical analysis; see the comparison example at the end of this document)
+- `-timeout T`: Maximum time the benchmarks may run (e.g. `10m` for 10 minutes)
+
+Common usage patterns:
+
+Single benchmark run with a model specified:
+```bash
+go test -bench=. ./benchmark/... -m llama3.3
+```
+
+## Output metrics
+
+The benchmark reports several key metrics:
+
+- `gen_tok/s`: Generated tokens per second
+- `prompt_tok/s`: Prompt processing tokens per second
+- `ttft_ms`: Time to first token in milliseconds
+- `load_ms`: Model load time in milliseconds
+- `gen_tokens`: Total tokens generated
+- `prompt_tokens`: Total prompt tokens processed
+
+Two scenarios are benchmarked:
+- Cold start: the model is unloaded before every iteration, so each measurement includes model load time
+- Warm start: the model is loaded and warmed up before measurement begins
+
+Three test cases are run for each scenario, differing in how many tokens the model is asked to generate (`num_predict`):
+- `short_prompt`: up to 100 generated tokens
+- `medium_prompt`: up to 500 generated tokens
+- `long_prompt`: up to 1000 generated tokens
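+
+## Comparing runs
+
+To measure the impact of a change, run the benchmarks before and after the change and save each set of results to a file. The commands below are an illustrative sketch: the flag values, model name, and file names are arbitrary, and `benchstat` (from `golang.org/x/perf/cmd/benchstat`, installed separately) is just one option for summarizing the difference.
+
+```bash
+# Run only the warm-start benchmarks, repeating each test case 5 times
+go test -bench=BenchmarkWarmStart ./benchmark/... -m llama3.3 -count 5 -timeout 30m | tee after.txt
+
+# Compare against results captured before the change
+benchstat before.txt after.txt
+```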