embedding tests: added check against exact base64 string (#12790)

This commit is contained in:
nicole pardal
2025-10-28 10:37:20 -07:00
committed by GitHub
parent 9862317174
commit 15c7d30d9a

View File

@@ -14,15 +14,16 @@ func TestToEmbeddingList(t *testing.T) {
embeddings [][]float32
format string
expectType string // "float" or "base64"
expectBase64 []string
expectCount int
promptEval int
}{
{"float format", [][]float32{{0.1, -0.2, 0.3}}, "float", "float", 1, 10},
{"base64 format", [][]float32{{0.1, -0.2, 0.3}}, "base64", "base64", 1, 5},
{"default to float", [][]float32{{0.1, -0.2, 0.3}}, "", "float", 1, 0},
{"invalid defaults to float", [][]float32{{0.1, -0.2, 0.3}}, "invalid", "float", 1, 0},
{"multiple embeddings", [][]float32{{0.1, 0.2}, {0.3, 0.4}, {0.5, 0.6}}, "base64", "base64", 3, 0},
{"empty embeddings", nil, "float", "", 0, 0},
{"float format", [][]float32{{0.1, -0.2, 0.3}}, "float", "float", nil, 1, 10},
{"base64 format", [][]float32{{0.1, -0.2, 0.3}}, "base64", "base64", []string{"zczMPc3MTL6amZk+"}, 1, 5},
{"default to float", [][]float32{{0.1, -0.2, 0.3}}, "", "float", nil, 1, 0},
{"invalid defaults to float", [][]float32{{0.1, -0.2, 0.3}}, "invalid", "float", nil, 1, 0},
{"multiple embeddings", [][]float32{{0.1, 0.2}, {0.3, 0.4}, {0.5, 0.6}}, "base64", "base64", []string{"zczMPc3MTD4=", "mpmZPs3MzD4=", "AAAAP5qZGT8="}, 3, 0},
{"empty embeddings", nil, "float", "", nil, 0, 0},
}
for _, tc := range testCases {
@@ -56,11 +57,24 @@ func TestToEmbeddingList(t *testing.T) {
t.Errorf("expected []float32, got %T", result.Data[0].Embedding)
}
case "base64":
embStr, ok := result.Data[0].Embedding.(string)
for i, data := range result.Data {
embStr, ok := data.Embedding.(string)
if !ok {
t.Errorf("expected string, got %T", result.Data[0].Embedding)
} else if _, err := base64.StdEncoding.DecodeString(embStr); err != nil {
t.Errorf("invalid base64: %v", err)
t.Errorf("embedding %d: expected string, got %T", i, data.Embedding)
continue
}
// Verify it's valid base64
if _, err := base64.StdEncoding.DecodeString(embStr); err != nil {
t.Errorf("embedding %d: invalid base64: %v", i, err)
}
// Compare against expected base64 string if provided
if tc.expectBase64 != nil && i < len(tc.expectBase64) {
if embStr != tc.expectBase64[i] {
t.Errorf("embedding %d: expected base64 %q, got %q", i, tc.expectBase64[i], embStr)
}
}
}
}