updated tests to loop through every embedding model

This commit is contained in:
nicole pardal
2025-10-30 11:15:07 -07:00
parent 2775db67bc
commit 0d9c873227

View File

@@ -303,67 +303,9 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req
} }
func TestEmbedTruncation(t *testing.T) { func TestEmbedTruncation(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) // Using adaptive soft/hard timeouts to avoid a single global 2m context
defer cancel() softTimeout, hardTimeout := getTimeouts(t)
client, _, cleanup := InitServerConnection(ctx, t) ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cleanup()
model := "all-minilm"
t.Run(model, func(t *testing.T) {
t.Run("truncation batch", func(t *testing.T) {
truncTrue := true
req := api.EmbedRequest{
Model: model,
Input: []string{"short", strings.Repeat("long ", 100), "medium text"},
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 30},
}
res, err := embedTestHelper(ctx, client, t, req)
if err != nil {
t.Fatal(err)
}
if len(res.Embeddings) != 3 {
t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
}
if res.PromptEvalCount > 90 {
t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount)
}
})
t.Run("runner token count accuracy", func(t *testing.T) {
baseline := api.EmbedRequest{Model: model, Input: "test"}
baseRes, err := embedTestHelper(ctx, client, t, baseline)
if err != nil {
t.Fatal(err)
}
batch := api.EmbedRequest{
Model: model,
Input: []string{"test", "test", "test"},
}
batchRes, err := embedTestHelper(ctx, client, t, batch)
if err != nil {
t.Fatal(err)
}
expectedCount := baseRes.PromptEvalCount * 3
if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 {
t.Fatalf("expected ~%d tokens (3 × %d), got %d",
expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount)
}
})
})
}
// TestEmbedStatusCode tests that errors from the embedding endpoint
// properly preserve their HTTP status codes when returned to the client.
// This test specifically checks the error handling path in EmbedHandler
// where api.StatusError errors should maintain their original status code.
func TestEmbedStatusCode(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel() defer cancel()
client, _, cleanup := InitServerConnection(ctx, t) client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup() defer cleanup()
@@ -371,8 +313,86 @@ func TestEmbedStatusCode(t *testing.T) {
for _, model := range libraryEmbedModels { for _, model := range libraryEmbedModels {
model := model model := model
t.Run(model, func(t *testing.T) { t.Run(model, func(t *testing.T) {
if time.Since(started) > softTimeout {
t.Skip("skipping remaining tests to avoid excessive runtime")
}
// Give each model its own budget to account for first-time pulls/loads
mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute)
defer mcancel()
t.Run("truncation batch", func(t *testing.T) {
truncTrue := true
req := api.EmbedRequest{
Model: model,
Input: []string{"short", strings.Repeat("long ", 100), "medium text"},
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 30},
}
res, err := embedTestHelper(mctx, client, t, req)
if err != nil {
t.Fatal(err)
}
if len(res.Embeddings) != 3 {
t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
}
if res.PromptEvalCount > 90 {
t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount)
}
})
t.Run("runner token count accuracy", func(t *testing.T) {
baseline := api.EmbedRequest{Model: model, Input: "test"}
baseRes, err := embedTestHelper(mctx, client, t, baseline)
if err != nil {
t.Fatal(err)
}
batch := api.EmbedRequest{
Model: model,
Input: []string{"test", "test", "test"},
}
batchRes, err := embedTestHelper(mctx, client, t, batch)
if err != nil {
t.Fatal(err)
}
expectedCount := baseRes.PromptEvalCount * 3
if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 {
t.Fatalf("expected ~%d tokens (3 × %d), got %d",
expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount)
}
})
})
}
}
// TestEmbedStatusCode tests that errors from the embedding endpoint
// properly preserve their HTTP status codes when returned to the client.
// This test specifically checks the error handling path in EmbedHandler
// where api.StatusError errors should maintain their original status code.
func TestEmbedStatusCode(t *testing.T) {
softTimeout, hardTimeout := getTimeouts(t)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
for _, model := range libraryEmbedModels {
model := model
t.Run(model, func(t *testing.T) {
if time.Since(started) > softTimeout {
t.Skip("skipping remaining tests to avoid excessive runtime")
}
mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute)
defer mcancel()
// Pull the model if needed // Pull the model if needed
if err := PullIfMissing(ctx, client, model); err != nil { if err := PullIfMissing(mctx, client, model); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -387,7 +407,7 @@ func TestEmbedStatusCode(t *testing.T) {
Options: map[string]any{"num_ctx": 10}, Options: map[string]any{"num_ctx": 10},
} }
_, err := embedTestHelper(ctx, client, t, req) _, err := embedTestHelper(mctx, client, t, req)
if err == nil { if err == nil {
t.Fatal("expected error when truncate=false with long input") t.Fatal("expected error when truncate=false with long input")
} }
@@ -423,7 +443,7 @@ func TestEmbedStatusCode(t *testing.T) {
Options: map[string]any{"num_ctx": 10}, Options: map[string]any{"num_ctx": 10},
} }
_, err := embedTestHelper(ctx, client, t, req) _, err := embedTestHelper(mctx, client, t, req)
if err == nil { if err == nil {
t.Fatal("expected error when one input exceeds context with truncate=false") t.Fatal("expected error when one input exceeds context with truncate=false")
} }