mirror of
https://github.com/ollama/ollama.git
synced 2025-11-10 15:07:46 +01:00
updated tests to loop through every embedding model
This commit is contained in:
@@ -303,13 +303,24 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestEmbedTruncation(t *testing.T) {
|
func TestEmbedTruncation(t *testing.T) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
// Using adaptive soft/hard timeouts to avoid a single global 2m context
|
||||||
|
softTimeout, hardTimeout := getTimeouts(t)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
model := "all-minilm"
|
for _, model := range libraryEmbedModels {
|
||||||
|
model := model
|
||||||
t.Run(model, func(t *testing.T) {
|
t.Run(model, func(t *testing.T) {
|
||||||
|
if time.Since(started) > softTimeout {
|
||||||
|
t.Skip("skipping remaining tests to avoid excessive runtime")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Give each model its own budget to account for first-time pulls/loads
|
||||||
|
mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute)
|
||||||
|
defer mcancel()
|
||||||
|
|
||||||
t.Run("truncation batch", func(t *testing.T) {
|
t.Run("truncation batch", func(t *testing.T) {
|
||||||
truncTrue := true
|
truncTrue := true
|
||||||
req := api.EmbedRequest{
|
req := api.EmbedRequest{
|
||||||
@@ -319,7 +330,7 @@ func TestEmbedTruncation(t *testing.T) {
|
|||||||
Options: map[string]any{"num_ctx": 30},
|
Options: map[string]any{"num_ctx": 30},
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := embedTestHelper(ctx, client, t, req)
|
res, err := embedTestHelper(mctx, client, t, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -335,7 +346,7 @@ func TestEmbedTruncation(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("runner token count accuracy", func(t *testing.T) {
|
t.Run("runner token count accuracy", func(t *testing.T) {
|
||||||
baseline := api.EmbedRequest{Model: model, Input: "test"}
|
baseline := api.EmbedRequest{Model: model, Input: "test"}
|
||||||
baseRes, err := embedTestHelper(ctx, client, t, baseline)
|
baseRes, err := embedTestHelper(mctx, client, t, baseline)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -344,7 +355,7 @@ func TestEmbedTruncation(t *testing.T) {
|
|||||||
Model: model,
|
Model: model,
|
||||||
Input: []string{"test", "test", "test"},
|
Input: []string{"test", "test", "test"},
|
||||||
}
|
}
|
||||||
batchRes, err := embedTestHelper(ctx, client, t, batch)
|
batchRes, err := embedTestHelper(mctx, client, t, batch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -357,13 +368,15 @@ func TestEmbedTruncation(t *testing.T) {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestEmbedStatusCode tests that errors from the embedding endpoint
|
// TestEmbedStatusCode tests that errors from the embedding endpoint
|
||||||
// properly preserve their HTTP status codes when returned to the client.
|
// properly preserve their HTTP status codes when returned to the client.
|
||||||
// This test specifically checks the error handling path in EmbedHandler
|
// This test specifically checks the error handling path in EmbedHandler
|
||||||
// where api.StatusError errors should maintain their original status code.
|
// where api.StatusError errors should maintain their original status code.
|
||||||
func TestEmbedStatusCode(t *testing.T) {
|
func TestEmbedStatusCode(t *testing.T) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
softTimeout, hardTimeout := getTimeouts(t)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
@@ -371,8 +384,15 @@ func TestEmbedStatusCode(t *testing.T) {
|
|||||||
for _, model := range libraryEmbedModels {
|
for _, model := range libraryEmbedModels {
|
||||||
model := model
|
model := model
|
||||||
t.Run(model, func(t *testing.T) {
|
t.Run(model, func(t *testing.T) {
|
||||||
|
if time.Since(started) > softTimeout {
|
||||||
|
t.Skip("skipping remaining tests to avoid excessive runtime")
|
||||||
|
}
|
||||||
|
|
||||||
|
mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute)
|
||||||
|
defer mcancel()
|
||||||
|
|
||||||
// Pull the model if needed
|
// Pull the model if needed
|
||||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
if err := PullIfMissing(mctx, client, model); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -387,7 +407,7 @@ func TestEmbedStatusCode(t *testing.T) {
|
|||||||
Options: map[string]any{"num_ctx": 10},
|
Options: map[string]any{"num_ctx": 10},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := embedTestHelper(ctx, client, t, req)
|
_, err := embedTestHelper(mctx, client, t, req)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error when truncate=false with long input")
|
t.Fatal("expected error when truncate=false with long input")
|
||||||
}
|
}
|
||||||
@@ -423,7 +443,7 @@ func TestEmbedStatusCode(t *testing.T) {
|
|||||||
Options: map[string]any{"num_ctx": 10},
|
Options: map[string]any{"num_ctx": 10},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := embedTestHelper(ctx, client, t, req)
|
_, err := embedTestHelper(mctx, client, t, req)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error when one input exceeds context with truncate=false")
|
t.Fatal("expected error when one input exceeds context with truncate=false")
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user