mirror of
https://github.com/ollama/ollama.git
synced 2025-11-11 01:18:01 +01:00
fix(integration): check truncated length (#12337)
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -44,9 +45,8 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
res, err := embeddingTestHelper(ctx, client, t, req)
|
res, err := embeddingTestHelper(ctx, client, t, req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error: %v", err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(res.Embedding) != 384 {
|
if len(res.Embedding) != 384 {
|
||||||
@@ -74,9 +74,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
res, err := embedTestHelper(ctx, client, t, req)
|
res, err := embedTestHelper(ctx, client, t, req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error: %v", err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(res.Embeddings) != 1 {
|
if len(res.Embeddings) != 1 {
|
||||||
@@ -112,9 +111,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
res, err := embedTestHelper(ctx, client, t, req)
|
res, err := embedTestHelper(ctx, client, t, req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error: %v", err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(res.Embeddings) != 2 {
|
if len(res.Embeddings) != 2 {
|
||||||
@@ -156,93 +154,135 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
|||||||
|
|
||||||
truncTrue, truncFalse := true, false
|
truncTrue, truncFalse := true, false
|
||||||
|
|
||||||
type testReq struct {
|
want, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
|
||||||
Name string
|
Model: "all-minilm",
|
||||||
Request api.EmbedRequest
|
Input: "why",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
reqs := []testReq{
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
request api.EmbedRequest
|
||||||
|
check func(*api.EmbedResponse, error)
|
||||||
|
}{
|
||||||
{
|
{
|
||||||
Name: "Target Truncation",
|
name: "target truncation",
|
||||||
Request: api.EmbedRequest{
|
request: api.EmbedRequest{
|
||||||
Model: "all-minilm",
|
Model: "all-minilm",
|
||||||
Input: "why",
|
Input: "why",
|
||||||
},
|
},
|
||||||
},
|
check: func(got *api.EmbedResponse, err error) {
|
||||||
{
|
if err != nil {
|
||||||
Name: "Default Truncate",
|
t.Fatal(err)
|
||||||
Request: api.EmbedRequest{
|
}
|
||||||
Model: "all-minilm",
|
|
||||||
Input: "why is the sky blue?",
|
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||||
Options: map[string]any{"num_ctx": 1},
|
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "Explicit Truncate",
|
name: "default truncate",
|
||||||
Request: api.EmbedRequest{
|
request: api.EmbedRequest{
|
||||||
|
Model: "all-minilm",
|
||||||
|
Input: "why is the sky blue?",
|
||||||
|
Options: map[string]any{"num_ctx": 3},
|
||||||
|
},
|
||||||
|
check: func(got *api.EmbedResponse, err error) {
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||||
|
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit truncate",
|
||||||
|
request: api.EmbedRequest{
|
||||||
|
Model: "all-minilm",
|
||||||
|
Input: "why is the sky blue?",
|
||||||
|
Truncate: &truncTrue,
|
||||||
|
Options: map[string]any{"num_ctx": 3},
|
||||||
|
},
|
||||||
|
check: func(got *api.EmbedResponse, err error) {
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||||
|
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "truncate error",
|
||||||
|
request: api.EmbedRequest{
|
||||||
|
Model: "all-minilm",
|
||||||
|
Input: "why is the sky blue?",
|
||||||
|
Truncate: &truncFalse,
|
||||||
|
Options: map[string]any{"num_ctx": 3},
|
||||||
|
},
|
||||||
|
check: func(res *api.EmbedResponse, err error) {
|
||||||
|
if err.Error() != "input exceeds maximum context length" {
|
||||||
|
t.Fatalf("expected truncation error, got: %v", err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "input after truncate error",
|
||||||
|
request: api.EmbedRequest{
|
||||||
Model: "all-minilm",
|
Model: "all-minilm",
|
||||||
Input: "why is the sky blue?",
|
Input: "why is the sky blue?",
|
||||||
Truncate: &truncTrue,
|
Truncate: &truncTrue,
|
||||||
Options: map[string]any{"num_ctx": 1},
|
Options: map[string]any{"num_ctx": 1},
|
||||||
},
|
},
|
||||||
|
check: func(res *api.EmbedResponse, err error) {
|
||||||
|
if err.Error() != "input after truncation exceeds maximum context length" {
|
||||||
|
t.Fatalf("expected truncation error, got: %v", err)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "input after truncate error",
|
||||||
|
request: api.EmbedRequest{
|
||||||
|
Model: "all-minilm",
|
||||||
|
Input: "why is the sky blue?",
|
||||||
|
Truncate: &truncTrue,
|
||||||
|
Options: map[string]any{"num_ctx": 0},
|
||||||
|
},
|
||||||
|
check: func(res *api.EmbedResponse, err error) {
|
||||||
|
if err.Error() != "input after truncation exceeds maximum context length" {
|
||||||
|
t.Fatalf("expected truncation error, got: %v", err)
|
||||||
|
}
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
res := make(map[string]*api.EmbedResponse)
|
for _, req := range cases {
|
||||||
|
t.Run(req.name, func(t *testing.T) {
|
||||||
for _, req := range reqs {
|
req.check(embedTestHelper(ctx, client, t, req.request))
|
||||||
response, err := embedTestHelper(ctx, client, t, req.Request)
|
})
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("error: %v", err)
|
|
||||||
}
|
|
||||||
res[req.Name] = response
|
|
||||||
}
|
|
||||||
|
|
||||||
if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
|
|
||||||
t.Fatal("expected default request to truncate correctly")
|
|
||||||
}
|
|
||||||
|
|
||||||
if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
|
|
||||||
t.Fatal("expected default request and truncate true request to be the same")
|
|
||||||
}
|
|
||||||
|
|
||||||
// check that truncate set to false returns an error if context length is exceeded
|
|
||||||
_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
|
|
||||||
Model: "all-minilm",
|
|
||||||
Input: "why is the sky blue?",
|
|
||||||
Truncate: &truncFalse,
|
|
||||||
Options: map[string]any{"num_ctx": 1},
|
|
||||||
})
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error, got nil")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
|
func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||||
t.Fatalf("failed to pull model %s: %v", req.Model, err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := client.Embeddings(ctx, &req)
|
return client.Embeddings(ctx, &req)
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return response, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||||
t.Fatalf("failed to pull model %s: %v", req.Model, err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := client.Embed(ctx, &req)
|
return client.Embed(ctx, &req)
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return response, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -634,7 +634,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
||||||
if len(tokens) > ctxLen {
|
if len(tokens) > ctxLen {
|
||||||
if !truncate {
|
if !truncate {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -646,6 +646,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
ctxLen--
|
ctxLen--
|
||||||
}
|
}
|
||||||
|
|
||||||
|
slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens))
|
||||||
|
if ctxLen <= 0 {
|
||||||
|
// return error if the truncated input would be empty or just special tokens
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
tokens = tokens[:ctxLen]
|
tokens = tokens[:ctxLen]
|
||||||
|
|
||||||
s, err = r.Detokenize(c.Request.Context(), tokens)
|
s, err = r.Detokenize(c.Request.Context(), tokens)
|
||||||
|
|||||||
Reference in New Issue
Block a user