mirror of
https://github.com/ollama/ollama.git
synced 2025-11-10 22:20:14 +01:00
Fixed failing test
This commit is contained in:
@@ -303,135 +303,58 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestEmbedTruncation(t *testing.T) {
|
func TestEmbedTruncation(t *testing.T) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
t.Run("single input token count", func(t *testing.T) {
|
model := "all-minilm"
|
||||||
req := api.EmbedRequest{
|
t.Run(model, func(t *testing.T) {
|
||||||
Model: "all-minilm",
|
t.Run("truncation batch", func(t *testing.T) {
|
||||||
Input: "why is the sky blue?",
|
truncTrue := true
|
||||||
}
|
req := api.EmbedRequest{
|
||||||
|
Model: model,
|
||||||
|
Input: []string{"short", strings.Repeat("long ", 100), "medium text"},
|
||||||
|
Truncate: &truncTrue,
|
||||||
|
Options: map[string]any{"num_ctx": 30},
|
||||||
|
}
|
||||||
|
|
||||||
res, err := embedTestHelper(ctx, client, t, req)
|
res, err := embedTestHelper(ctx, client, t, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if res.PromptEvalCount <= 0 {
|
if len(res.Embeddings) != 3 {
|
||||||
t.Fatalf("expected positive token count, got %d", res.PromptEvalCount)
|
t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
|
||||||
}
|
}
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("batch parallel token counting", func(t *testing.T) {
|
if res.PromptEvalCount > 90 {
|
||||||
req := api.EmbedRequest{
|
t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount)
|
||||||
Model: "all-minilm",
|
}
|
||||||
Input: []string{"cat", "dog and mouse", "bird"},
|
})
|
||||||
}
|
|
||||||
|
|
||||||
res, err := embedTestHelper(ctx, client, t, req)
|
t.Run("runner token count accuracy", func(t *testing.T) {
|
||||||
if err != nil {
|
baseline := api.EmbedRequest{Model: model, Input: "test"}
|
||||||
t.Fatal(err)
|
baseRes, err := embedTestHelper(ctx, client, t, baseline)
|
||||||
}
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
if len(res.Embeddings) != 3 {
|
batch := api.EmbedRequest{
|
||||||
t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
|
Model: model,
|
||||||
}
|
Input: []string{"test", "test", "test"},
|
||||||
|
}
|
||||||
|
batchRes, err := embedTestHelper(ctx, client, t, batch)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
if res.PromptEvalCount <= 0 {
|
expectedCount := baseRes.PromptEvalCount * 3
|
||||||
t.Fatalf("expected positive token count, got %d", res.PromptEvalCount)
|
if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 {
|
||||||
}
|
t.Fatalf("expected ~%d tokens (3 × %d), got %d",
|
||||||
})
|
expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount)
|
||||||
|
}
|
||||||
t.Run("truncation single input", func(t *testing.T) {
|
})
|
||||||
truncTrue := true
|
|
||||||
longInput := strings.Repeat("word ", 100)
|
|
||||||
|
|
||||||
req := api.EmbedRequest{
|
|
||||||
Model: "all-minilm",
|
|
||||||
Input: longInput,
|
|
||||||
Truncate: &truncTrue,
|
|
||||||
Options: map[string]any{"num_ctx": 50},
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := embedTestHelper(ctx, client, t, req)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.PromptEvalCount > 50 {
|
|
||||||
t.Fatalf("expected tokens <= 50 after truncation, got %d", res.PromptEvalCount)
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.PromptEvalCount == 0 {
|
|
||||||
t.Fatal("expected non-zero token count after truncation")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("truncation batch", func(t *testing.T) {
|
|
||||||
truncTrue := true
|
|
||||||
req := api.EmbedRequest{
|
|
||||||
Model: "all-minilm",
|
|
||||||
Input: []string{"short", strings.Repeat("long ", 100), "medium text"},
|
|
||||||
Truncate: &truncTrue,
|
|
||||||
Options: map[string]any{"num_ctx": 30},
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := embedTestHelper(ctx, client, t, req)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(res.Embeddings) != 3 {
|
|
||||||
t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.PromptEvalCount > 90 {
|
|
||||||
t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("truncate false error", func(t *testing.T) {
|
|
||||||
truncFalse := false
|
|
||||||
req := api.EmbedRequest{
|
|
||||||
Model: "all-minilm",
|
|
||||||
Input: strings.Repeat("word ", 100),
|
|
||||||
Truncate: &truncFalse,
|
|
||||||
Options: map[string]any{"num_ctx": 10},
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := embedTestHelper(ctx, client, t, req)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error when truncate=false with long input")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(err.Error(), "exceeds maximum context length") {
|
|
||||||
t.Fatalf("expected context length error, got: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("runner token count accuracy", func(t *testing.T) {
|
|
||||||
baseline := api.EmbedRequest{Model: "all-minilm", Input: "test"}
|
|
||||||
baseRes, err := embedTestHelper(ctx, client, t, baseline)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
batch := api.EmbedRequest{
|
|
||||||
Model: "all-minilm",
|
|
||||||
Input: []string{"test", "test", "test"},
|
|
||||||
}
|
|
||||||
batchRes, err := embedTestHelper(ctx, client, t, batch)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedCount := baseRes.PromptEvalCount * 3
|
|
||||||
if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 {
|
|
||||||
t.Fatalf("expected ~%d tokens (3 × %d), got %d",
|
|
||||||
expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -445,72 +368,77 @@ func TestEmbedStatusCode(t *testing.T) {
|
|||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
// Pull the model if needed
|
for _, model := range libraryEmbedModels {
|
||||||
if err := PullIfMissing(ctx, client, "all-minilm"); err != nil {
|
model := model
|
||||||
t.Fatal(err)
|
t.Run(model, func(t *testing.T) {
|
||||||
|
// Pull the model if needed
|
||||||
|
if err := PullIfMissing(ctx, client, model); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("truncation error status code", func(t *testing.T) {
|
||||||
|
truncFalse := false
|
||||||
|
longInput := strings.Repeat("word ", 100)
|
||||||
|
|
||||||
|
req := api.EmbedRequest{
|
||||||
|
Model: model,
|
||||||
|
Input: longInput,
|
||||||
|
Truncate: &truncFalse,
|
||||||
|
Options: map[string]any{"num_ctx": 10},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := embedTestHelper(ctx, client, t, req)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when truncate=false with long input")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that it's a StatusError with the correct status code
|
||||||
|
var statusErr api.StatusError
|
||||||
|
if !errors.As(err, &statusErr) {
|
||||||
|
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The error should be a 4xx client error (likely 400 Bad Request)
|
||||||
|
// not a 500 Internal Server Error
|
||||||
|
if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
|
||||||
|
t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the error message is meaningful
|
||||||
|
if !strings.Contains(err.Error(), "context length") {
|
||||||
|
t.Errorf("expected error message to mention context length, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("batch truncation error status code", func(t *testing.T) {
|
||||||
|
truncFalse := false
|
||||||
|
req := api.EmbedRequest{
|
||||||
|
Model: model,
|
||||||
|
Input: []string{
|
||||||
|
"short input",
|
||||||
|
strings.Repeat("very long input ", 100),
|
||||||
|
"another short input",
|
||||||
|
},
|
||||||
|
Truncate: &truncFalse,
|
||||||
|
Options: map[string]any{"num_ctx": 10},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := embedTestHelper(ctx, client, t, req)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when one input exceeds context with truncate=false")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that it's a StatusError with the correct status code
|
||||||
|
var statusErr api.StatusError
|
||||||
|
if !errors.As(err, &statusErr) {
|
||||||
|
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The error should be a 4xx client error, not a 500 Internal Server Error
|
||||||
|
if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
|
||||||
|
t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("truncation error status code", func(t *testing.T) {
|
|
||||||
truncFalse := false
|
|
||||||
longInput := strings.Repeat("word ", 100)
|
|
||||||
|
|
||||||
req := api.EmbedRequest{
|
|
||||||
Model: "all-minilm",
|
|
||||||
Input: longInput,
|
|
||||||
Truncate: &truncFalse,
|
|
||||||
Options: map[string]any{"num_ctx": 10},
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := embedTestHelper(ctx, client, t, req)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error when truncate=false with long input")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that it's a StatusError with the correct status code
|
|
||||||
var statusErr api.StatusError
|
|
||||||
if !errors.As(err, &statusErr) {
|
|
||||||
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The error should be a 4xx client error (likely 400 Bad Request)
|
|
||||||
// not a 500 Internal Server Error
|
|
||||||
if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
|
|
||||||
t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the error message is meaningful
|
|
||||||
if !strings.Contains(err.Error(), "context length") {
|
|
||||||
t.Errorf("expected error message to mention context length, got: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("batch truncation error status code", func(t *testing.T) {
|
|
||||||
truncFalse := false
|
|
||||||
req := api.EmbedRequest{
|
|
||||||
Model: "all-minilm",
|
|
||||||
Input: []string{
|
|
||||||
"short input",
|
|
||||||
strings.Repeat("very long input ", 100),
|
|
||||||
"another short input",
|
|
||||||
},
|
|
||||||
Truncate: &truncFalse,
|
|
||||||
Options: map[string]any{"num_ctx": 10},
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := embedTestHelper(ctx, client, t, req)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error when one input exceeds context with truncate=false")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that it's a StatusError with the correct status code
|
|
||||||
var statusErr api.StatusError
|
|
||||||
if !errors.As(err, &statusErr) {
|
|
||||||
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The error should be a 4xx client error, not a 500 Internal Server Error
|
|
||||||
if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
|
|
||||||
t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user