mirror of
https://github.com/ollama/ollama.git
synced 2025-11-10 15:37:27 +01:00
282 lines
7.2 KiB
Go
282 lines
7.2 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"runtime"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/ollama/ollama/api"
|
|
)
|
|
|
|
// flagOptions bundles the pointers returned by the flag package in main.
// Every field is dereferenced only after flag.Parse has run.
type flagOptions struct {
	// models holds the -model value: one or more model names separated by commas.
	models *string

	// epochs is the number of requests per model (-epochs); maxTokens is sent
	// as num_predict only when positive (-max-tokens).
	epochs, maxTokens *int

	// temperature is sent with every request (-temperature).
	temperature *float64

	// seed is sent only when positive (-seed); timeout is the per-request
	// limit in seconds (-timeout).
	seed, timeout *int

	// prompt is the user message (-p); imageFile optionally names an image to
	// attach to it (-image).
	prompt, imageFile *string

	// keepAlive is a duration in seconds (-k); when positive it is sent with
	// each request and also slept after each one.
	keepAlive *float64

	// format selects the output format, benchstat or csv (-format);
	// outputFile holds the -output flag value.
	format, outputFile *string

	// debug echoes streamed responses to stderr (-debug); verbose adds system
	// information to benchstat output (-v).
	debug, verbose *bool
}
|
|
|
|
// Metrics is one timing measurement for a single step of a benchmark request.
type Metrics struct {
	// Model is the name of the benchmarked model.
	Model string

	// Step names the measured phase. BenchmarkChat produces "prefill",
	// "generate", "load" and "total"; OutputMetrics reports "prefill" and
	// "generate" per token and every other step per request.
	Step string

	// Count is the number of tokens processed for "prefill" and "generate",
	// and 1 for whole-request steps.
	Count int

	// Duration is the time spent in the step.
	Duration time.Duration
}

// once makes OutputMetrics emit its one-time header (the system information
// of verbose benchstat output, or the CSV column row) only on the first call
// that reaches it, so repeated calls across models and epochs do not repeat it.
var once sync.Once

// isTokenStep reports whether step is measured per token rather than per
// request.
func isTokenStep(step string) bool {
	return step == "generate" || step == "prefill"
}

// tokenRates returns the latency per token in nanoseconds and the throughput
// in tokens per second for m. Both are zero when m.Count is not positive.
func tokenRates(m Metrics) (nsPerToken, tokensPerSec float64) {
	if m.Count <= 0 {
		return 0, 0
	}
	ns := float64(m.Duration.Nanoseconds())
	nsPerToken = ns / float64(m.Count)
	// The tiny epsilon keeps a zero reported duration from dividing by zero.
	tokensPerSec = float64(m.Count) / (ns + 1e-12) * 1e9
	return nsPerToken, tokensPerSec
}

// OutputMetrics writes metrics to w in the given format.
//
// "benchstat" emits one line per metric in Go benchmark syntax. When verbose
// is set, a sysname/machine header is written first, but only once per
// process. "csv" writes a header row on the first call and then one row per
// metric. Any other format is reported on stderr and nothing is written to w.
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
	switch format {
	case "benchstat":
		if verbose {
			// Write the header to w, not stdout, so it stays with the results
			// it describes when w is a file or buffer.
			once.Do(func() {
				fmt.Fprintf(w, "sysname: %s\n", runtime.GOOS)
				fmt.Fprintf(w, "machine: %s\n", runtime.GOARCH)
			})
		}

		for _, m := range metrics {
			if !isTokenStep(m.Step) {
				// Whole-request steps: only "load" carries a step label.
				var suffix string
				if m.Step == "load" {
					suffix = "/step=load"
				}
				fmt.Fprintf(w, "BenchmarkModel/name=%s%s 1 %d ns/request\n",
					m.Model, suffix, m.Duration.Nanoseconds())
				continue
			}
			if m.Count <= 0 {
				fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d 0 ns/token 0 token/sec\n",
					m.Model, m.Step, m.Count)
				continue
			}
			nsPerToken, tokensPerSec := tokenRates(m)
			fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d %.2f ns/token %.2f token/sec\n",
				m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
		}

	case "csv":
		once.Do(func() {
			headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
			fmt.Fprintln(w, strings.Join(headings, ","))
		})

		for _, m := range metrics {
			if isTokenStep(m.Step) {
				nsPerToken, tokensPerSec := tokenRates(m)
				fmt.Fprintf(w, "%s,%s,%d,%.2f,%.2f\n", m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
			} else {
				fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
			}
		}

	default:
		fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
	}
}
|
|
|
|
func BenchmarkChat(fOpt flagOptions) error {
|
|
models := strings.Split(*fOpt.models, ",")
|
|
|
|
// todo - add multi-image support
|
|
var imgData api.ImageData
|
|
var err error
|
|
if *fOpt.imageFile != "" {
|
|
imgData, err = readImage(*fOpt.imageFile)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "ERROR: Couldn't read image '%s': %v\n", *fOpt.imageFile, err)
|
|
return err
|
|
}
|
|
}
|
|
|
|
if *fOpt.debug && imgData != nil {
|
|
fmt.Fprintf(os.Stderr, "Read file '%s'\n", *fOpt.imageFile)
|
|
}
|
|
|
|
client, err := api.ClientFromEnvironment()
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "ERROR: Couldn't create ollama client: %v\n", err)
|
|
return err
|
|
}
|
|
|
|
for _, model := range models {
|
|
for range *fOpt.epochs {
|
|
options := make(map[string]interface{})
|
|
if *fOpt.maxTokens > 0 {
|
|
options["num_predict"] = *fOpt.maxTokens
|
|
}
|
|
options["temperature"] = *fOpt.temperature
|
|
if fOpt.seed != nil && *fOpt.seed > 0 {
|
|
options["seed"] = *fOpt.seed
|
|
}
|
|
|
|
var keepAliveDuration *api.Duration
|
|
if *fOpt.keepAlive > 0 {
|
|
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
|
|
keepAliveDuration = &duration
|
|
}
|
|
|
|
req := &api.ChatRequest{
|
|
Model: model,
|
|
Messages: []api.Message{
|
|
{
|
|
Role: "user",
|
|
Content: *fOpt.prompt,
|
|
},
|
|
},
|
|
Options: options,
|
|
KeepAlive: keepAliveDuration,
|
|
}
|
|
|
|
if imgData != nil {
|
|
req.Messages[0].Images = []api.ImageData{imgData}
|
|
}
|
|
|
|
var responseMetrics *api.Metrics
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
|
|
defer cancel()
|
|
|
|
err = client.Chat(ctx, req, func(resp api.ChatResponse) error {
|
|
if *fOpt.debug {
|
|
fmt.Fprintf(os.Stderr, "%s", resp.Message.Content)
|
|
}
|
|
|
|
if resp.Done {
|
|
responseMetrics = &resp.Metrics
|
|
}
|
|
return nil
|
|
})
|
|
|
|
if *fOpt.debug {
|
|
fmt.Fprintln(os.Stderr)
|
|
}
|
|
|
|
if err != nil {
|
|
if ctx.Err() == context.DeadlineExceeded {
|
|
fmt.Fprintf(os.Stderr, "ERROR: Chat request timed out with model '%s' after %vs\n", model, 1)
|
|
continue
|
|
}
|
|
fmt.Fprintf(os.Stderr, "ERROR: Couldn't chat with model '%s': %v\n", model, err)
|
|
continue
|
|
}
|
|
|
|
if responseMetrics == nil {
|
|
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
|
|
continue
|
|
}
|
|
|
|
metrics := []Metrics{
|
|
{
|
|
Model: model,
|
|
Step: "prefill",
|
|
Count: responseMetrics.PromptEvalCount,
|
|
Duration: responseMetrics.PromptEvalDuration,
|
|
},
|
|
{
|
|
Model: model,
|
|
Step: "generate",
|
|
Count: responseMetrics.EvalCount,
|
|
Duration: responseMetrics.EvalDuration,
|
|
},
|
|
{
|
|
Model: model,
|
|
Step: "load",
|
|
Count: 1,
|
|
Duration: responseMetrics.LoadDuration,
|
|
},
|
|
{
|
|
Model: model,
|
|
Step: "total",
|
|
Count: 1,
|
|
Duration: responseMetrics.TotalDuration,
|
|
},
|
|
}
|
|
|
|
OutputMetrics(os.Stdout, *fOpt.format, metrics, *fOpt.verbose)
|
|
|
|
if *fOpt.keepAlive > 0 {
|
|
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func readImage(filePath string) (api.ImageData, error) {
|
|
file, err := os.Open(filePath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer file.Close()
|
|
|
|
data, err := io.ReadAll(file)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return api.ImageData(data), nil
|
|
}
|
|
|
|
func main() {
|
|
fOpt := flagOptions{
|
|
models: flag.String("model", "", "Model to benchmark"),
|
|
epochs: flag.Int("epochs", 1, "Number of epochs (iterations) per model"),
|
|
maxTokens: flag.Int("max-tokens", 0, "Maximum tokens for model response"),
|
|
temperature: flag.Float64("temperature", 0, "Temperature parameter"),
|
|
seed: flag.Int("seed", 0, "Random seed"),
|
|
timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"),
|
|
prompt: flag.String("p", "Write a long story.", "Prompt to use"),
|
|
imageFile: flag.String("image", "", "Filename for an image to include"),
|
|
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
|
|
format: flag.String("format", "benchstat", "Output format [benchstat|csv] (default benchstat)"),
|
|
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
|
|
verbose: flag.Bool("v", false, "Show system information"),
|
|
debug: flag.Bool("debug", false, "Show debug information"),
|
|
}
|
|
|
|
flag.Usage = func() {
|
|
fmt.Fprintf(os.Stderr, "Usage: %s [OPTIONS]\n\n", os.Args[0])
|
|
fmt.Fprintf(os.Stderr, "Description:\n")
|
|
fmt.Fprintf(os.Stderr, " Model benchmarking tool with configurable parameters\n\n")
|
|
fmt.Fprintf(os.Stderr, "Options:\n")
|
|
flag.PrintDefaults()
|
|
fmt.Fprintf(os.Stderr, "\nExamples:\n")
|
|
fmt.Fprintf(os.Stderr, " ollama-bench -model gpt-oss:20b -epochs 3 -temperature 0.7\n")
|
|
}
|
|
flag.Parse()
|
|
|
|
if !slices.Contains([]string{"benchstat", "csv"}, *fOpt.format) {
|
|
fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
|
|
os.Exit(1)
|
|
}
|
|
|
|
if len(*fOpt.models) == 0 {
|
|
flag.Usage()
|
|
return
|
|
}
|
|
|
|
BenchmarkChat(fOpt)
|
|
}
|