diff --git a/integration/embed_test.go b/integration/embed_test.go index df80e387a7..83722f1dec 100644 --- a/integration/embed_test.go +++ b/integration/embed_test.go @@ -303,67 +303,9 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req } func TestEmbedTruncation(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) - defer cancel() - client, _, cleanup := InitServerConnection(ctx, t) - defer cleanup() - - model := "all-minilm" - t.Run(model, func(t *testing.T) { - t.Run("truncation batch", func(t *testing.T) { - truncTrue := true - req := api.EmbedRequest{ - Model: model, - Input: []string{"short", strings.Repeat("long ", 100), "medium text"}, - Truncate: &truncTrue, - Options: map[string]any{"num_ctx": 30}, - } - - res, err := embedTestHelper(ctx, client, t, req) - if err != nil { - t.Fatal(err) - } - - if len(res.Embeddings) != 3 { - t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings)) - } - - if res.PromptEvalCount > 90 { - t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount) - } - }) - - t.Run("runner token count accuracy", func(t *testing.T) { - baseline := api.EmbedRequest{Model: model, Input: "test"} - baseRes, err := embedTestHelper(ctx, client, t, baseline) - if err != nil { - t.Fatal(err) - } - - batch := api.EmbedRequest{ - Model: model, - Input: []string{"test", "test", "test"}, - } - batchRes, err := embedTestHelper(ctx, client, t, batch) - if err != nil { - t.Fatal(err) - } - - expectedCount := baseRes.PromptEvalCount * 3 - if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 { - t.Fatalf("expected ~%d tokens (3 × %d), got %d", - expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount) - } - }) - }) -} - -// TestEmbedStatusCode tests that errors from the embedding endpoint -// properly preserve their HTTP status codes when returned to the client. -// This test specifically checks the error handling path in EmbedHandler -// where api.StatusError errors should maintain their original status code. -func TestEmbedStatusCode(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + // Using adaptive soft/hard timeouts to avoid a single global 2m context + softTimeout, hardTimeout := getTimeouts(t) + ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) defer cancel() client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() @@ -371,8 +313,86 @@ func TestEmbedStatusCode(t *testing.T) { for _, model := range libraryEmbedModels { model := model t.Run(model, func(t *testing.T) { + if time.Since(started) > softTimeout { + t.Skip("skipping remaining tests to avoid excessive runtime") + } + + // Give each model its own budget to account for first-time pulls/loads + mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute) + defer mcancel() + + t.Run("truncation batch", func(t *testing.T) { + truncTrue := true + req := api.EmbedRequest{ + Model: model, + Input: []string{"short", strings.Repeat("long ", 100), "medium text"}, + Truncate: &truncTrue, + Options: map[string]any{"num_ctx": 30}, + } + + res, err := embedTestHelper(mctx, client, t, req) + if err != nil { + t.Fatal(err) + } + + if len(res.Embeddings) != 3 { + t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings)) + } + + if res.PromptEvalCount > 90 { + t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount) + } + }) + + t.Run("runner token count accuracy", func(t *testing.T) { + baseline := api.EmbedRequest{Model: model, Input: "test"} + baseRes, err := embedTestHelper(mctx, client, t, baseline) + if err != nil { + t.Fatal(err) + } + + batch := api.EmbedRequest{ + Model: model, + Input: []string{"test", "test", "test"}, + } + batchRes, err := embedTestHelper(mctx, client, t, batch) + if err != nil { + t.Fatal(err) + } + + expectedCount := baseRes.PromptEvalCount * 3 + if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 { + t.Fatalf("expected ~%d tokens (3 × %d), got %d", + expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount) + } + }) + }) + } +} + +// TestEmbedStatusCode tests that errors from the embedding endpoint +// properly preserve their HTTP status codes when returned to the client. +// This test specifically checks the error handling path in EmbedHandler +// where api.StatusError errors should maintain their original status code. +func TestEmbedStatusCode(t *testing.T) { + softTimeout, hardTimeout := getTimeouts(t) + ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + for _, model := range libraryEmbedModels { + model := model + t.Run(model, func(t *testing.T) { + if time.Since(started) > softTimeout { + t.Skip("skipping remaining tests to avoid excessive runtime") + } + + mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute) + defer mcancel() + // Pull the model if needed - if err := PullIfMissing(ctx, client, model); err != nil { + if err := PullIfMissing(mctx, client, model); err != nil { t.Fatal(err) } @@ -387,7 +407,7 @@ func TestEmbedStatusCode(t *testing.T) { Options: map[string]any{"num_ctx": 10}, } - _, err := embedTestHelper(ctx, client, t, req) + _, err := embedTestHelper(mctx, client, t, req) if err == nil { t.Fatal("expected error when truncate=false with long input") } @@ -423,7 +443,7 @@ func TestEmbedStatusCode(t *testing.T) { Options: map[string]any{"num_ctx": 10}, } - _, err := embedTestHelper(ctx, client, t, req) + _, err := embedTestHelper(mctx, client, t, req) if err == nil { t.Fatal("expected error when one input exceeds context with truncate=false") }