From 2775db67bcdb7bcbce12e4685fa20351d2211877 Mon Sep 17 00:00:00 2001 From: nicole pardal Date: Thu, 30 Oct 2025 10:49:22 -0700 Subject: [PATCH] Fixed failing test --- integration/embed_test.go | 300 +++++++++++++++----------------------- 1 file changed, 114 insertions(+), 186 deletions(-) diff --git a/integration/embed_test.go b/integration/embed_test.go index 29c5ff9aa1..df80e387a7 100644 --- a/integration/embed_test.go +++ b/integration/embed_test.go @@ -303,135 +303,58 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req } func TestEmbedTruncation(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() - t.Run("single input token count", func(t *testing.T) { - req := api.EmbedRequest{ - Model: "all-minilm", - Input: "why is the sky blue?", - } + model := "all-minilm" + t.Run(model, func(t *testing.T) { + t.Run("truncation batch", func(t *testing.T) { + truncTrue := true + req := api.EmbedRequest{ + Model: model, + Input: []string{"short", strings.Repeat("long ", 100), "medium text"}, + Truncate: &truncTrue, + Options: map[string]any{"num_ctx": 30}, + } - res, err := embedTestHelper(ctx, client, t, req) - if err != nil { - t.Fatal(err) - } + res, err := embedTestHelper(ctx, client, t, req) + if err != nil { + t.Fatal(err) + } - if res.PromptEvalCount <= 0 { - t.Fatalf("expected positive token count, got %d", res.PromptEvalCount) - } - }) + if len(res.Embeddings) != 3 { + t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings)) + } - t.Run("batch parallel token counting", func(t *testing.T) { - req := api.EmbedRequest{ - Model: "all-minilm", - Input: []string{"cat", "dog and mouse", "bird"}, - } + if res.PromptEvalCount > 90 { + t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount) + } + }) - res, err := embedTestHelper(ctx, client, t, req) - if err != nil { - t.Fatal(err) - } + t.Run("runner token count accuracy", func(t *testing.T) { + baseline := api.EmbedRequest{Model: model, Input: "test"} + baseRes, err := embedTestHelper(ctx, client, t, baseline) + if err != nil { + t.Fatal(err) + } - if len(res.Embeddings) != 3 { - t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings)) - } + batch := api.EmbedRequest{ + Model: model, + Input: []string{"test", "test", "test"}, + } + batchRes, err := embedTestHelper(ctx, client, t, batch) + if err != nil { + t.Fatal(err) + } - if res.PromptEvalCount <= 0 { - t.Fatalf("expected positive token count, got %d", res.PromptEvalCount) - } - }) - - t.Run("truncation single input", func(t *testing.T) { - truncTrue := true - longInput := strings.Repeat("word ", 100) - - req := api.EmbedRequest{ - Model: "all-minilm", - Input: longInput, - Truncate: &truncTrue, - Options: map[string]any{"num_ctx": 50}, - } - - res, err := embedTestHelper(ctx, client, t, req) - if err != nil { - t.Fatal(err) - } - - if res.PromptEvalCount > 50 { - t.Fatalf("expected tokens <= 50 after truncation, got %d", res.PromptEvalCount) - } - - if res.PromptEvalCount == 0 { - t.Fatal("expected non-zero token count after truncation") - } - }) - - t.Run("truncation batch", func(t *testing.T) { - truncTrue := true - req := api.EmbedRequest{ - Model: "all-minilm", - Input: []string{"short", strings.Repeat("long ", 100), "medium text"}, - Truncate: &truncTrue, - Options: map[string]any{"num_ctx": 30}, - } - - res, err := embedTestHelper(ctx, client, t, req) - if err != nil { - t.Fatal(err) - } - - if len(res.Embeddings) != 3 { - t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings)) - } - - if res.PromptEvalCount > 90 { - t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount) - } - }) - - t.Run("truncate false error", func(t *testing.T) { - truncFalse := false - req := api.EmbedRequest{ - Model: "all-minilm", - Input: strings.Repeat("word ", 100), - Truncate: &truncFalse, - Options: map[string]any{"num_ctx": 10}, - } - - _, err := embedTestHelper(ctx, client, t, req) - if err == nil { - t.Fatal("expected error when truncate=false with long input") - } - - if !strings.Contains(err.Error(), "exceeds maximum context length") { - t.Fatalf("expected context length error, got: %v", err) - } - }) - - t.Run("runner token count accuracy", func(t *testing.T) { - baseline := api.EmbedRequest{Model: "all-minilm", Input: "test"} - baseRes, err := embedTestHelper(ctx, client, t, baseline) - if err != nil { - t.Fatal(err) - } - - batch := api.EmbedRequest{ - Model: "all-minilm", - Input: []string{"test", "test", "test"}, - } - batchRes, err := embedTestHelper(ctx, client, t, batch) - if err != nil { - t.Fatal(err) - } - - expectedCount := baseRes.PromptEvalCount * 3 - if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 { - t.Fatalf("expected ~%d tokens (3 × %d), got %d", - expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount) - } + expectedCount := baseRes.PromptEvalCount * 3 + if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 { + t.Fatalf("expected ~%d tokens (3 × %d), got %d", + expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount) + } + }) }) } @@ -445,72 +368,77 @@ func TestEmbedStatusCode(t *testing.T) { client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() - // Pull the model if needed - if err := PullIfMissing(ctx, client, "all-minilm"); err != nil { - t.Fatal(err) + for _, model := range libraryEmbedModels { + model := model + t.Run(model, func(t *testing.T) { + // Pull the model if needed + if err := PullIfMissing(ctx, client, model); err != nil { + t.Fatal(err) + } + + t.Run("truncation error status code", func(t *testing.T) { + truncFalse := false + longInput := strings.Repeat("word ", 100) + + req := api.EmbedRequest{ + Model: model, + Input: longInput, + Truncate: &truncFalse, + Options: map[string]any{"num_ctx": 10}, + } + + _, err := embedTestHelper(ctx, client, t, req) + if err == nil { + t.Fatal("expected error when truncate=false with long input") + } + + // Check that it's a StatusError with the correct status code + var statusErr api.StatusError + if !errors.As(err, &statusErr) { + t.Fatalf("expected api.StatusError, got %T: %v", err, err) + } + + // The error should be a 4xx client error (likely 400 Bad Request) + // not a 500 Internal Server Error + if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 { + t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode) + } + + // Verify the error message is meaningful + if !strings.Contains(err.Error(), "context length") { + t.Errorf("expected error message to mention context length, got: %v", err) + } + }) + + t.Run("batch truncation error status code", func(t *testing.T) { + truncFalse := false + req := api.EmbedRequest{ + Model: model, + Input: []string{ + "short input", + strings.Repeat("very long input ", 100), + "another short input", + }, + Truncate: &truncFalse, + Options: map[string]any{"num_ctx": 10}, + } + + _, err := embedTestHelper(ctx, client, t, req) + if err == nil { + t.Fatal("expected error when one input exceeds context with truncate=false") + } + + // Check that it's a StatusError with the correct status code + var statusErr api.StatusError + if !errors.As(err, &statusErr) { + t.Fatalf("expected api.StatusError, got %T: %v", err, err) + } + + // The error should be a 4xx client error, not a 500 Internal Server Error + if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 { + t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode) + } + }) + }) } - - t.Run("truncation error status code", func(t *testing.T) { - truncFalse := false - longInput := strings.Repeat("word ", 100) - - req := api.EmbedRequest{ - Model: "all-minilm", - Input: longInput, - Truncate: &truncFalse, - Options: map[string]any{"num_ctx": 10}, - } - - _, err := embedTestHelper(ctx, client, t, req) - if err == nil { - t.Fatal("expected error when truncate=false with long input") - } - - // Check that it's a StatusError with the correct status code - var statusErr api.StatusError - if !errors.As(err, &statusErr) { - t.Fatalf("expected api.StatusError, got %T: %v", err, err) - } - - // The error should be a 4xx client error (likely 400 Bad Request) - // not a 500 Internal Server Error - if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 { - t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode) - } - - // Verify the error message is meaningful - if !strings.Contains(err.Error(), "context length") { - t.Errorf("expected error message to mention context length, got: %v", err) - } - }) - - t.Run("batch truncation error status code", func(t *testing.T) { - truncFalse := false - req := api.EmbedRequest{ - Model: "all-minilm", - Input: []string{ - "short input", - strings.Repeat("very long input ", 100), - "another short input", - }, - Truncate: &truncFalse, - Options: map[string]any{"num_ctx": 10}, - } - - _, err := embedTestHelper(ctx, client, t, req) - if err == nil { - t.Fatal("expected error when one input exceeds context with truncate=false") - } - - // Check that it's a StatusError with the correct status code - var statusErr api.StatusError - if !errors.As(err, &statusErr) { - t.Fatalf("expected api.StatusError, got %T: %v", err, err) - } - - // The error should be a 4xx client error, not a 500 Internal Server Error - if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 { - t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode) - } - }) }