fix(integration): check truncated length (#12337)

This commit is contained in:
Michael Yang
2025-09-18 14:00:21 -07:00
committed by GitHub
parent 2717dce6fe
commit ceac416ec2
2 changed files with 113 additions and 66 deletions

View File

@@ -8,6 +8,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
@@ -44,9 +45,8 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
} }
res, err := embeddingTestHelper(ctx, client, t, req) res, err := embeddingTestHelper(ctx, client, t, req)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatal(err)
} }
if len(res.Embedding) != 384 { if len(res.Embedding) != 384 {
@@ -74,9 +74,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
} }
res, err := embedTestHelper(ctx, client, t, req) res, err := embedTestHelper(ctx, client, t, req)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatal(err)
} }
if len(res.Embeddings) != 1 { if len(res.Embeddings) != 1 {
@@ -112,9 +111,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
} }
res, err := embedTestHelper(ctx, client, t, req) res, err := embedTestHelper(ctx, client, t, req)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatal(err)
} }
if len(res.Embeddings) != 2 { if len(res.Embeddings) != 2 {
@@ -156,93 +154,135 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
truncTrue, truncFalse := true, false truncTrue, truncFalse := true, false
type testReq struct { want, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
Name string Model: "all-minilm",
Request api.EmbedRequest Input: "why",
})
if err != nil {
t.Fatal(err)
} }
reqs := []testReq{ cases := []struct {
name string
request api.EmbedRequest
check func(*api.EmbedResponse, error)
}{
{ {
Name: "Target Truncation", name: "target truncation",
Request: api.EmbedRequest{ request: api.EmbedRequest{
Model: "all-minilm", Model: "all-minilm",
Input: "why", Input: "why",
}, },
}, check: func(got *api.EmbedResponse, err error) {
{ if err != nil {
Name: "Default Truncate", t.Fatal(err)
Request: api.EmbedRequest{ }
Model: "all-minilm",
Input: "why is the sky blue?", if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
Options: map[string]any{"num_ctx": 1}, t.Errorf("embedding mismatch (-want +got):\n%s", diff)
}
}, },
}, },
{ {
Name: "Explicit Truncate", name: "default truncate",
Request: api.EmbedRequest{ request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Options: map[string]any{"num_ctx": 3},
},
check: func(got *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
}
},
},
{
name: "explicit truncate",
request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 3},
},
check: func(got *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
}
},
},
{
name: "truncate error",
request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 3},
},
check: func(res *api.EmbedResponse, err error) {
if err.Error() != "input exceeds maximum context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
},
},
{
name: "input after truncate error",
request: api.EmbedRequest{
Model: "all-minilm", Model: "all-minilm",
Input: "why is the sky blue?", Input: "why is the sky blue?",
Truncate: &truncTrue, Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 1}, Options: map[string]any{"num_ctx": 1},
}, },
check: func(res *api.EmbedResponse, err error) {
if err.Error() != "input after truncation exceeds maximum context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
},
},
{
name: "input after truncate error",
request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 0},
},
check: func(res *api.EmbedResponse, err error) {
if err.Error() != "input after truncation exceeds maximum context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
},
}, },
} }
res := make(map[string]*api.EmbedResponse) for _, req := range cases {
t.Run(req.name, func(t *testing.T) {
for _, req := range reqs { req.check(embedTestHelper(ctx, client, t, req.request))
response, err := embedTestHelper(ctx, client, t, req.Request) })
if err != nil {
t.Fatalf("error: %v", err)
}
res[req.Name] = response
}
if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
t.Fatal("expected default request to truncate correctly")
}
if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
t.Fatal("expected default request and truncate true request to be the same")
}
// check that truncate set to false returns an error if context length is exceeded
_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 1},
})
if err == nil {
t.Fatal("expected error, got nil")
} }
} }
func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) { func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
t.Helper()
if err := PullIfMissing(ctx, client, req.Model); err != nil { if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err) t.Fatal(err)
} }
response, err := client.Embeddings(ctx, &req) return client.Embeddings(ctx, &req)
if err != nil {
return nil, err
}
return response, nil
} }
func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) { func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
t.Helper()
if err := PullIfMissing(ctx, client, req.Model); err != nil { if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err) t.Fatal(err)
} }
response, err := client.Embed(ctx, &req) return client.Embed(ctx, &req)
if err != nil {
return nil, err
}
return response, nil
} }

View File

@@ -634,7 +634,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
ctxLen := min(opts.NumCtx, int(kvData.ContextLength())) ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen { if len(tokens) > ctxLen {
if !truncate { if !truncate {
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"}) c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"})
return return
} }
@@ -646,6 +646,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
ctxLen-- ctxLen--
} }
slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens))
if ctxLen <= 0 {
// return error if the truncated input would be empty or just special tokens
c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"})
return
}
tokens = tokens[:ctxLen] tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens) s, err = r.Detokenize(c.Request.Context(), tokens)