From 90698c7d15e4cd73b5e518f3be85d74d857af764 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Tue, 28 Jan 2025 14:55:03 -0800
Subject: [PATCH] benchmark: new Go runner

---
 benchmark/README.md                    |  25 +++
 benchmark/new_runner.sh                |  72 ++++++
 benchmark/new_runner_benchmark_test.go | 173 +++++++++++++++
 benchmark/server_benchmark_test.go     | 293 +++++++++++++++++++++++++
 4 files changed, 563 insertions(+)
 create mode 100644 benchmark/README.md
 create mode 100644 benchmark/new_runner.sh
 create mode 100644 benchmark/new_runner_benchmark_test.go
 create mode 100644 benchmark/server_benchmark_test.go

diff --git a/benchmark/README.md b/benchmark/README.md
new file mode 100644
index 000000000..d3309059c
--- /dev/null
+++ b/benchmark/README.md
@@ -0,0 +1,25 @@
+# Benchmark
+
+Performance benchmarking for Ollama.
+
+## Prerequisites
+- Ollama server running locally (`127.0.0.1:11434`)
+- Desired models pre-downloaded (e.g., `llama3.2:1b`)
+
+## Run Benchmark
+```bash
+# Run all tests
+go test -bench=. -timeout 30m ./...
+```
+
+## New Runner Benchmark
+Requires a runner already serving on `localhost:8080`:
+```bash
+go test -bench=Runner
+```
+
+Or, to start a runner and benchmark several models in sequence:
+```bash
+# run this from within the benchmark directory
+# requires: llama3.2:1b, llama3.1:8b, llama3.3:70b
+sh new_runner.sh
+```
diff --git a/benchmark/new_runner.sh b/benchmark/new_runner.sh
new file mode 100644
index 000000000..371261856
--- /dev/null
+++ b/benchmark/new_runner.sh
@@ -0,0 +1,72 @@
+#!/bin/bash
+
+kill_process_tree() {
+    local pid=$1
+    # Get all child processes using pgrep
+    local children=$(pgrep -P "$pid")
+
+    # Kill children first
+    for child in $children; do
+        kill_process_tree "$child"
+    done
+
+    # Kill the parent process
+    kill -9 "$pid" 2>/dev/null || true
+}
+
+# Function to run the runner and benchmark for a given model
+run_benchmark() {
+    local model=$1
+
+    echo "Starting runner with model: $model"
+    # Start the runner in the background and save its PID
+    go run ../cmd/runner/main.go --new-runner -model "$model" &
+    runner_pid=$!
+
+    # Wait for the runner to initialize (adjust sleep time as needed)
+    sleep 5
+
+    echo "Running benchmark..."
+    # Run the test and wait for it to complete
+    go test -bench=Runner
+    test_exit_code=$?
+
+    echo "Stopping runner process..."
+    # Kill the runner process and all its children
+    kill_process_tree "$runner_pid"
+
+    # Wait for the process to fully terminate
+    wait "$runner_pid" 2>/dev/null || true
+
+    # Make sure no processes are still listening on port 8080
+    lsof -t -i:8080 | xargs kill -9 2>/dev/null || true
+
+    # Additional sleep to ensure the port is freed
+    sleep 2
+
+    # Check if the test failed
+    if [ "$test_exit_code" -ne 0 ]; then
+        echo "Warning: Benchmark test failed with exit code $test_exit_code"
+    fi
+
+    echo "Benchmark complete for model: $model"
+    echo "----------------------------------------"
+}
+
+HOME_DIR="$HOME"
+# llama3.2:1b: ~/.ollama/models/blobs/sha256-74701a8c35f6c8d9a4b91f3f3497643001d63e0c7a84e085bed452548fa88d45
+# llama3.1:8b: ~/.ollama/models/blobs/sha256-667b0c1932bc6ffc593ed1d03f895bf2dc8dc6df21db3042284a6f4416b06a29
+# llama3.3:70b: ~/.ollama/models/blobs/sha256-4824460d29f2058aaf6e1118a63a7a197a09bed509f0e7d4e2efb1ee273b447d
+models=(
+    "${HOME_DIR}/.ollama/models/blobs/sha256-74701a8c35f6c8d9a4b91f3f3497643001d63e0c7a84e085bed452548fa88d45"
+    "${HOME_DIR}/.ollama/models/blobs/sha256-667b0c1932bc6ffc593ed1d03f895bf2dc8dc6df21db3042284a6f4416b06a29"
+    # "${HOME_DIR}/.ollama/models/blobs/sha256-4824460d29f2058aaf6e1118a63a7a197a09bed509f0e7d4e2efb1ee273b447d"
+)
+
+# Run benchmarks for each model
+for model in "${models[@]}"; do
+    run_benchmark "$model"
+done
+
+echo "All benchmarks completed!"
diff --git a/benchmark/new_runner_benchmark_test.go b/benchmark/new_runner_benchmark_test.go
new file mode 100644
index 000000000..e93562b84
--- /dev/null
+++ b/benchmark/new_runner_benchmark_test.go
@@ -0,0 +1,173 @@
+package benchmark
+
+import (
+    "bytes"
+    "context"
+    "encoding/json"
+    "fmt"
+    "io"
+    "net/http"
+    "testing"
+    "time"
+)
+
+const (
+    runnerURL     = "http://localhost:8080"
+    warmupPrompts = 2  // Number of warm-up requests per test case
+    warmupTokens  = 50 // Smaller token count for warm-up requests
+)
+
+// CompletionRequest represents the request body for the completion endpoint
+type CompletionRequest struct {
+    Prompt     string `json:"prompt"`
+    NumPredict int    `json:"n_predict"`
+}
+
+// CompletionResponse represents a single response chunk from the streaming API
+type CompletionResponse struct {
+    Content string `json:"content"`
+    Stop    bool   `json:"stop"`
+    Timings struct {
+        PredictedN  int `json:"predicted_n"`
+        PredictedMs int `json:"predicted_ms"`
+        PromptN     int `json:"prompt_n"`
+        PromptMs    int `json:"prompt_ms"`
+    } `json:"timings"`
+}
+
+// warmUp performs warm-up requests before the actual benchmark
+func warmUp(b *testing.B, tt TestCase) {
+    b.Logf("Warming up for test case %s", tt.name)
+    warmupTest := TestCase{
+        name:      tt.name + "_warmup",
+        prompt:    tt.prompt,
+        maxTokens: warmupTokens,
+    }
+
+    for i := 0; i < warmupPrompts; i++ {
+        runCompletion(context.Background(), warmupTest, b)
+        time.Sleep(100 * time.Millisecond) // Brief pause between warm-up requests
+    }
+    b.Logf("Warm-up complete")
+}
+
+func BenchmarkRunnerInference(b *testing.B) {
+    b.Logf("Starting benchmark suite")
+
+    // Verify runner availability
+    resp, err := http.Get(runnerURL + "/health")
+    if err != nil {
+        b.Fatalf("Runner unavailable: %v", err)
+    }
+    resp.Body.Close()
+    b.Log("Runner available")
+
+    tests := []TestCase{
+        {
+            name:      "short_prompt",
+            prompt:    formatPrompt("Write a long story"),
+            maxTokens: 100,
+        },
+        {
+            name:      "medium_prompt",
+            prompt:    formatPrompt("Write a detailed economic analysis"),
+            maxTokens: 500,
+        },
+        {
+            name:      "long_prompt",
+            prompt:    formatPrompt("Write a comprehensive AI research paper"),
+            maxTokens: 1000,
+        },
+    }
+
+    // Register cleanup handler for results reporting
+    b.Cleanup(func() { reportMetrics(metrics) })
+
+    // Main benchmark loop
+    for _, tt := range tests {
+        b.Run(tt.name, func(b *testing.B) {
+            // Perform warm-up requests
+            warmUp(b, tt)
+
+            // Wait a bit after warm-up before starting the actual benchmark
+            time.Sleep(500 * time.Millisecond)
+
+            m := make([]BenchmarkMetrics, b.N)
+
+            // Reset the timer once, after warm-up, so setup is excluded
+            b.ResetTimer()
+            for i := 0; i < b.N; i++ {
+                m[i] = runCompletion(context.Background(), tt, b)
+            }
+            metrics = append(metrics, m...)
+        })
+    }
+}
+
+func formatPrompt(text string) string {
+    return fmt.Sprintf("<|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", text)
+}
+
+func runCompletion(ctx context.Context, tt TestCase, b *testing.B) BenchmarkMetrics {
+    start := time.Now()
+    var ttft time.Duration
+    var tokens int
+    lastToken := start
+
+    // Create request body
+    reqBody := CompletionRequest{
+        Prompt:     tt.prompt,
+        NumPredict: tt.maxTokens,
+    }
+    jsonData, err := json.Marshal(reqBody)
+    if err != nil {
+        b.Fatalf("Failed to marshal request: %v", err)
+    }
+
+    // Create HTTP request
+    req, err := http.NewRequestWithContext(ctx, "POST", runnerURL+"/completion", bytes.NewBuffer(jsonData))
+    if err != nil {
+        b.Fatalf("Failed to create request: %v", err)
+    }
+    req.Header.Set("Content-Type", "application/json")
+
+    // Execute request
+    resp, err := http.DefaultClient.Do(req)
+    if err != nil {
+        b.Fatalf("Request failed: %v", err)
+    }
+    defer resp.Body.Close()
+    if resp.StatusCode != http.StatusOK {
+        b.Fatalf("Unexpected status code: %d", resp.StatusCode)
+    }
+
+    // Process streaming response
+    decoder := json.NewDecoder(resp.Body)
+    for {
+        var chunk CompletionResponse
+        if err := decoder.Decode(&chunk); err != nil {
+            if err == io.EOF {
+                break
+            }
+            b.Fatalf("Failed to decode response: %v", err)
+        }
+
+        if ttft == 0 && chunk.Content != "" {
+            ttft = time.Since(start)
+        }
+
+        if chunk.Content != "" {
+            tokens++
+            lastToken = time.Now()
+        }
+
+        if chunk.Stop {
+            break
+        }
+    }
+
+    totalTime := lastToken.Sub(start)
+    tokensPerSecond := 0.0
+    if totalTime > 0 {
+        tokensPerSecond = float64(tokens) / totalTime.Seconds()
+    }
+    return BenchmarkMetrics{
+        testName:        tt.name,
+        ttft:            ttft,
+        totalTime:       totalTime,
+        totalTokens:     tokens,
+        tokensPerSecond: tokensPerSecond,
+    }
+}
diff --git a/benchmark/server_benchmark_test.go b/benchmark/server_benchmark_test.go
new file mode 100644
index 000000000..65d056b76
--- /dev/null
+++ b/benchmark/server_benchmark_test.go
@@ -0,0 +1,293 @@
+// Package benchmark provides tools for performance testing of the Ollama inference server and supported models.
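+//
+// Each run records time to first token (TTFT), total generation time, the
+// number of generated tokens, and the derived tokens/second throughput.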
+package benchmark
+
+import (
+    "context"
+    "fmt"
+    "net/http"
+    "net/url"
+    "os"
+    "sort"
+    "testing"
+    "text/tabwriter"
+    "time"
+
+    "github.com/ollama/ollama/api"
+)
+
+// serverURL is the default Ollama server URL for benchmarking
+const serverURL = "http://127.0.0.1:11434"
+
+// metrics collects all benchmark results for final reporting
+var metrics []BenchmarkMetrics
+
+// models contains the list of model names to benchmark
+var models = []string{
+    "llama3.2:1b",
+    // "qwen2.5:7b",
+    // "llama3.3:70b",
+}
+
+// TestCase defines a benchmark test scenario with prompt characteristics
+type TestCase struct {
+    name      string // Human-readable test name
+    prompt    string // Input prompt text
+    maxTokens int    // Maximum tokens to generate
+}
+
+// BenchmarkMetrics contains performance measurements for a single test run
+type BenchmarkMetrics struct {
+    model           string        // Model being tested
+    scenario        string        // cold_start or warm_start
+    testName        string        // Name of the test case
+    ttft            time.Duration // Time To First Token (TTFT)
+    totalTime       time.Duration // Total time for complete response
+    totalTokens     int           // Total generated tokens
+    tokensPerSecond float64       // Calculated throughput
+}
+
+// ScenarioType defines the initialization state for benchmarking
+type ScenarioType int
+
+const (
+    ColdStart ScenarioType = iota // Model is loaded from cold state
+    WarmStart                     // Model is already loaded in memory
+)
+
+// String implements fmt.Stringer for ScenarioType
+func (s ScenarioType) String() string {
+    return [...]string{"cold_start", "warm_start"}[s]
+}
+
+// BenchmarkServerInference is the main entry point for benchmarking Ollama inference performance.
+// It tests all configured models with different prompt lengths and start scenarios.
+func BenchmarkServerInference(b *testing.B) {
+    b.Logf("Starting benchmark suite with %d models", len(models))
+
+    // Verify server availability
+    resp, err := http.Get(serverURL + "/api/version")
+    if err != nil {
+        b.Fatalf("Server unavailable: %v", err)
+    }
+    resp.Body.Close()
+    b.Log("Server available")
+
+    tests := []TestCase{
+        {"short_prompt", "Write a long story", 100},
+        {"medium_prompt", "Write a detailed economic analysis", 500},
+        {"long_prompt", "Write a comprehensive AI research paper", 1000},
+    }
+
+    // Register cleanup handler for results reporting
+    b.Cleanup(func() { reportMetrics(metrics) })
+
+    // Main benchmark loop: cold-start runs for each test, then warm-start runs
+    for _, model := range models {
+        client := api.NewClient(mustParse(serverURL), http.DefaultClient)
+        // Verify model availability
+        if _, err := client.Show(context.Background(), &api.ShowRequest{Model: model}); err != nil {
+            b.Fatalf("Model unavailable: %v", err)
+        }
+
+        for _, scenario := range []ScenarioType{ColdStart, WarmStart} {
+            for _, tt := range tests {
+                testName := fmt.Sprintf("%s/%s/%s", model, scenario, tt.name)
+                b.Run(testName, func(b *testing.B) {
+                    m := runBenchmark(b, tt, model, scenario, client)
+                    metrics = append(metrics, m...)
+                })
+            }
+        }
+    }
+}
+
+// runBenchmark executes multiple iterations of a specific test case and scenario.
+// Returns collected metrics for all iterations.
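+// Per-iteration setup (model unload or pre-warm) is excluded from the Go
+// benchmark timer via StopTimer/StartTimer.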
+func runBenchmark(b *testing.B, tt TestCase, model string, scenario ScenarioType, client *api.Client) []BenchmarkMetrics {
+    results := make([]BenchmarkMetrics, b.N)
+
+    // Run benchmark iterations
+    for i := 0; i < b.N; i++ {
+        b.StopTimer()
+        switch scenario {
+        case WarmStart:
+            // Pre-warm the model by generating some tokens
+            for j := 0; j < 2; j++ {
+                err := client.Generate(
+                    context.Background(),
+                    &api.GenerateRequest{
+                        Model:   model,
+                        Prompt:  tt.prompt,
+                        Options: map[string]interface{}{"num_predict": tt.maxTokens, "temperature": 0.1},
+                    },
+                    func(api.GenerateResponse) error { return nil },
+                )
+                if err != nil {
+                    b.Logf("Warm-up error: %v", err)
+                }
+            }
+        case ColdStart:
+            unloadModel(client, model, b)
+        }
+        b.StartTimer()
+
+        results[i] = runSingleIteration(context.Background(), client, tt, model, b)
+        results[i].scenario = scenario.String()
+    }
+    return results
+}
+
+// unloadModel forces model unloading by setting KeepAlive to 0.
+// Includes a short delay to ensure unloading completes before the next test.
+func unloadModel(client *api.Client, model string, b *testing.B) {
+    req := &api.GenerateRequest{
+        Model:     model,
+        KeepAlive: &api.Duration{Duration: 0},
+    }
+    if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
+        b.Logf("Unload error: %v", err)
+    }
+    time.Sleep(100 * time.Millisecond)
+}
+
+// runSingleIteration measures performance metrics for a single inference request.
+// Captures TTFT, total generation time, and calculates tokens/second.
+func runSingleIteration(ctx context.Context, client *api.Client, tt TestCase, model string, b *testing.B) BenchmarkMetrics {
+    start := time.Now()
+    var ttft time.Duration
+    var tokens int
+    lastToken := start
+
+    req := &api.GenerateRequest{
+        Model:   model,
+        Prompt:  tt.prompt,
+        Options: map[string]interface{}{"num_predict": tt.maxTokens, "temperature": 0.1},
+    }
+
+    if b != nil {
+        b.Logf("Prompt length: %d chars", len(tt.prompt))
+    }
+
+    // Execute generation request with metrics collection
+    err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
+        if ttft == 0 && resp.Response != "" {
+            ttft = time.Since(start)
+        }
+        if resp.Response != "" {
+            tokens++
+            lastToken = time.Now()
+        }
+        return nil
+    })
+    if err != nil && b != nil {
+        b.Logf("Generate error: %v", err)
+    }
+
+    totalTime := lastToken.Sub(start)
+    tokensPerSecond := 0.0
+    if totalTime > 0 {
+        tokensPerSecond = float64(tokens) / totalTime.Seconds()
+    }
+    return BenchmarkMetrics{
+        model:           model,
+        testName:        tt.name,
+        ttft:            ttft,
+        totalTime:       totalTime,
+        totalTokens:     tokens,
+        tokensPerSecond: tokensPerSecond,
+    }
+}
+
+// reportMetrics processes collected metrics and prints formatted results.
+// Generates both human-readable tables and CSV output with averaged statistics.
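+// Throughput is computed from the summed token count and summed generation
+// time across iterations, rather than by averaging per-iteration rates.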
+func reportMetrics(results []BenchmarkMetrics) {
+    if len(results) == 0 {
+        return
+    }
+
+    // Aggregate results by test case
+    type statsKey struct {
+        model    string
+        scenario string
+        testName string
+    }
+    type stat struct {
+        ttftSum      time.Duration
+        totalTimeSum time.Duration
+        tokensSum    int
+        iterations   int
+    }
+    stats := make(map[statsKey]*stat)
+
+    for _, m := range results {
+        key := statsKey{m.model, m.scenario, m.testName}
+        if _, exists := stats[key]; !exists {
+            stats[key] = &stat{}
+        }
+
+        stats[key].ttftSum += m.ttft
+        stats[key].totalTimeSum += m.totalTime
+        stats[key].tokensSum += m.totalTokens
+        stats[key].iterations++
+    }
+
+    // Calculate averages
+    var averaged []BenchmarkMetrics
+    for key, data := range stats {
+        count := data.iterations
+        averaged = append(averaged, BenchmarkMetrics{
+            model:           key.model,
+            scenario:        key.scenario,
+            testName:        key.testName,
+            ttft:            data.ttftSum / time.Duration(count),
+            totalTime:       data.totalTimeSum / time.Duration(count),
+            totalTokens:     data.tokensSum / count,
+            tokensPerSecond: float64(data.tokensSum) / data.totalTimeSum.Seconds(),
+        })
+    }
+
+    // Sort for deterministic output (map iteration order is randomized)
+    sort.Slice(averaged, func(i, j int) bool {
+        if averaged[i].model != averaged[j].model {
+            return averaged[i].model < averaged[j].model
+        }
+        if averaged[i].scenario != averaged[j].scenario {
+            return averaged[i].scenario < averaged[j].scenario
+        }
+        return averaged[i].testName < averaged[j].testName
+    })
+
+    // Print formatted results
+    printTableResults(averaged)
+    printCSVResults(averaged)
+}
+
+// printTableResults displays averaged metrics in a formatted table
+func printTableResults(averaged []BenchmarkMetrics) {
+    w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
+    fmt.Fprintln(w, "\nAVERAGED BENCHMARK RESULTS")
+    fmt.Fprintln(w, "Model\tScenario\tTest Name\tTTFT (ms)\tTotal Time (ms)\tTokens\tTokens/sec")
+    for _, m := range averaged {
+        fmt.Fprintf(w, "%s\t%s\t%s\t%.2f\t%.2f\t%d\t%.2f\n",
+            m.model,
+            m.scenario,
+            m.testName,
+            float64(m.ttft.Milliseconds()),
+            float64(m.totalTime.Milliseconds()),
+            m.totalTokens,
+            m.tokensPerSecond,
+        )
+    }
+    w.Flush()
+}
+
+// printCSVResults outputs averaged metrics in CSV format
+func printCSVResults(averaged []BenchmarkMetrics) {
+    fmt.Println("\nCSV OUTPUT")
+    fmt.Println("model,scenario,test_name,ttft_ms,total_ms,tokens,tokens_per_sec")
+    for _, m := range averaged {
+        fmt.Printf("%s,%s,%s,%.2f,%.2f,%d,%.2f\n",
+            m.model,
+            m.scenario,
+            m.testName,
+            float64(m.ttft.Milliseconds()),
+            float64(m.totalTime.Milliseconds()),
+            m.totalTokens,
+            m.tokensPerSecond,
+        )
+    }
+}
+
+// mustParse parses a URL and panics on error
+func mustParse(rawURL string) *url.URL {
+    u, err := url.Parse(rawURL)
+    if err != nil {
+        panic(err)
+    }
+    return u
+}
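Usage sketch, assuming a local Ollama server with `llama3.2:1b` pulled and the repository root as the working directory; `-benchtime` is a standard `go test` flag and `results.txt` is an arbitrary output file:

```bash
# Run each server benchmark scenario 3 times and keep a copy of the report
go test -bench=ServerInference -benchtime=3x -timeout 30m ./benchmark | tee results.txt
```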