perf: build graph for next batch async to keep GPU busy (#11863)

* perf: build graph for next batch in parallel to keep GPU busy

This refactors the main run loop of the ollama runner to perform the main GPU
intensive tasks (Compute+Floats) in a go routine so we can prepare the next
batch in parallel to reduce the amount of time the GPU stalls waiting for the
next batch of work.

* tests: tune integration tests for ollama engine

This tunes the integration tests to focus more on models supported
by the new engine.
This commit is contained in:
Daniel Hiltgen
2025-08-29 14:20:28 -07:00
committed by GitHub
parent ead4a9a1d0
commit 517807cdf2
20 changed files with 591 additions and 235 deletions

View File

@@ -13,12 +13,12 @@ import (
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/ollama/ollama/api"
)
func TestMaxQueue(t *testing.T) {
t.Skip("this test needs to be re-evaluated to use a proper embedding model")
if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size")
return
@@ -45,7 +45,9 @@ func TestMaxQueue(t *testing.T) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, req.Model))
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatal(err)
}
// Context for the worker threads so we can shut them down
// embedCtx, embedCancel := context.WithCancel(ctx)
@@ -89,7 +91,9 @@ func TestMaxQueue(t *testing.T) {
switch {
case genErr == nil:
successCount++
require.Greater(t, len(resp.Embedding), 5) // somewhat arbitrary, but sufficient to be reasonable
if len(resp.Embedding) < 5 { // somewhat arbitrary, but sufficient to be reasonable
t.Fatalf("embeddings shorter than expected: %d", len(resp.Embedding))
}
case errors.Is(genErr, context.Canceled):
canceledCount++
case strings.Contains(genErr.Error(), "busy"):
@@ -97,7 +101,9 @@ func TestMaxQueue(t *testing.T) {
case strings.Contains(genErr.Error(), "connection reset by peer"):
resetByPeerCount++
default:
require.NoError(t, genErr, "%d request failed", i)
if genErr != nil {
t.Fatalf("%d request failed", i)
}
}
slog.Info("embed finished", "id", i)
@@ -108,8 +114,13 @@ func TestMaxQueue(t *testing.T) {
embedwg.Wait()
slog.Info("embeds completed", "success", successCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
require.Equal(t, resetByPeerCount, 0, "Connections reset by peer, have you updated your fd and socket limits?")
require.True(t, busyCount > 0, "no requests hit busy error but some should have")
require.True(t, canceledCount == 0, "no requests should have been canceled due to timeout")
if resetByPeerCount != 0 {
t.Fatalf("Connections reset by peer, have you updated your fd and socket limits? %d", resetByPeerCount)
}
if busyCount == 0 {
t.Fatalf("no requests hit busy error but some should have")
}
if canceledCount > 0 {
t.Fatalf("no requests should have been canceled due to timeout %d", canceledCount)
}
}