mirror of
https://github.com/ollama/ollama.git
synced 2025-11-11 07:37:34 +01:00
embedding tests: added check against exact base64 string (#12790)
This commit is contained in:
@@ -14,15 +14,16 @@ func TestToEmbeddingList(t *testing.T) {
|
||||
embeddings [][]float32
|
||||
format string
|
||||
expectType string // "float" or "base64"
|
||||
expectBase64 []string
|
||||
expectCount int
|
||||
promptEval int
|
||||
}{
|
||||
{"float format", [][]float32{{0.1, -0.2, 0.3}}, "float", "float", 1, 10},
|
||||
{"base64 format", [][]float32{{0.1, -0.2, 0.3}}, "base64", "base64", 1, 5},
|
||||
{"default to float", [][]float32{{0.1, -0.2, 0.3}}, "", "float", 1, 0},
|
||||
{"invalid defaults to float", [][]float32{{0.1, -0.2, 0.3}}, "invalid", "float", 1, 0},
|
||||
{"multiple embeddings", [][]float32{{0.1, 0.2}, {0.3, 0.4}, {0.5, 0.6}}, "base64", "base64", 3, 0},
|
||||
{"empty embeddings", nil, "float", "", 0, 0},
|
||||
{"float format", [][]float32{{0.1, -0.2, 0.3}}, "float", "float", nil, 1, 10},
|
||||
{"base64 format", [][]float32{{0.1, -0.2, 0.3}}, "base64", "base64", []string{"zczMPc3MTL6amZk+"}, 1, 5},
|
||||
{"default to float", [][]float32{{0.1, -0.2, 0.3}}, "", "float", nil, 1, 0},
|
||||
{"invalid defaults to float", [][]float32{{0.1, -0.2, 0.3}}, "invalid", "float", nil, 1, 0},
|
||||
{"multiple embeddings", [][]float32{{0.1, 0.2}, {0.3, 0.4}, {0.5, 0.6}}, "base64", "base64", []string{"zczMPc3MTD4=", "mpmZPs3MzD4=", "AAAAP5qZGT8="}, 3, 0},
|
||||
{"empty embeddings", nil, "float", "", nil, 0, 0},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -56,11 +57,24 @@ func TestToEmbeddingList(t *testing.T) {
|
||||
t.Errorf("expected []float32, got %T", result.Data[0].Embedding)
|
||||
}
|
||||
case "base64":
|
||||
embStr, ok := result.Data[0].Embedding.(string)
|
||||
for i, data := range result.Data {
|
||||
embStr, ok := data.Embedding.(string)
|
||||
if !ok {
|
||||
t.Errorf("expected string, got %T", result.Data[0].Embedding)
|
||||
} else if _, err := base64.StdEncoding.DecodeString(embStr); err != nil {
|
||||
t.Errorf("invalid base64: %v", err)
|
||||
t.Errorf("embedding %d: expected string, got %T", i, data.Embedding)
|
||||
continue
|
||||
}
|
||||
|
||||
// Verify it's valid base64
|
||||
if _, err := base64.StdEncoding.DecodeString(embStr); err != nil {
|
||||
t.Errorf("embedding %d: invalid base64: %v", i, err)
|
||||
}
|
||||
|
||||
// Compare against expected base64 string if provided
|
||||
if tc.expectBase64 != nil && i < len(tc.expectBase64) {
|
||||
if embStr != tc.expectBase64[i] {
|
||||
t.Errorf("embedding %d: expected base64 %q, got %q", i, tc.expectBase64[i], embStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user