Files
ollama/integration/embed_truncation_test.go
2025-10-23 16:28:55 -07:00

146 lines
3.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//go:build integration
package integration
import (
"context"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func TestEmbedTruncationRefactor(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
t.Run("single input token count", func(t *testing.T) {
req := api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
}
res, err := embedTestHelper(ctx, client, t, req)
if err != nil {
t.Fatal(err)
}
if res.PromptEvalCount <= 0 {
t.Fatalf("expected positive token count, got %d", res.PromptEvalCount)
}
})
t.Run("batch parallel token counting", func(t *testing.T) {
req := api.EmbedRequest{
Model: "all-minilm",
Input: []string{"cat", "dog and mouse", "bird"},
}
res, err := embedTestHelper(ctx, client, t, req)
if err != nil {
t.Fatal(err)
}
if len(res.Embeddings) != 3 {
t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
}
if res.PromptEvalCount <= 0 {
t.Fatalf("expected positive token count, got %d", res.PromptEvalCount)
}
})
t.Run("truncation single input", func(t *testing.T) {
truncTrue := true
longInput := strings.Repeat("word ", 100)
req := api.EmbedRequest{
Model: "all-minilm",
Input: longInput,
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 50},
}
res, err := embedTestHelper(ctx, client, t, req)
if err != nil {
t.Fatal(err)
}
if res.PromptEvalCount > 50 {
t.Fatalf("expected tokens <= 50 after truncation, got %d", res.PromptEvalCount)
}
if res.PromptEvalCount == 0 {
t.Fatal("expected non-zero token count after truncation")
}
})
t.Run("truncation batch", func(t *testing.T) {
truncTrue := true
req := api.EmbedRequest{
Model: "all-minilm",
Input: []string{"short", strings.Repeat("long ", 100), "medium text"},
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 30},
}
res, err := embedTestHelper(ctx, client, t, req)
if err != nil {
t.Fatal(err)
}
if len(res.Embeddings) != 3 {
t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
}
if res.PromptEvalCount > 90 {
t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount)
}
})
t.Run("truncate false error", func(t *testing.T) {
truncFalse := false
req := api.EmbedRequest{
Model: "all-minilm",
Input: strings.Repeat("word ", 100),
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 10},
}
_, err := embedTestHelper(ctx, client, t, req)
if err == nil {
t.Fatal("expected error when truncate=false with long input")
}
if !strings.Contains(err.Error(), "exceeds maximum context length") {
t.Fatalf("expected context length error, got: %v", err)
}
})
t.Run("runner token count accuracy", func(t *testing.T) {
baseline := api.EmbedRequest{Model: "all-minilm", Input: "test"}
baseRes, err := embedTestHelper(ctx, client, t, baseline)
if err != nil {
t.Fatal(err)
}
batch := api.EmbedRequest{
Model: "all-minilm",
Input: []string{"test", "test", "test"},
}
batchRes, err := embedTestHelper(ctx, client, t, batch)
if err != nil {
t.Fatal(err)
}
expectedCount := baseRes.PromptEvalCount * 3
if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 {
t.Fatalf("expected ~%d tokens (3 × %d), got %d",
expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount)
}
})
}