mirror of
https://github.com/ollama/ollama.git
synced 2025-11-11 07:37:34 +01:00
embedding tests: added check against exact base64 string (#12790)
This commit is contained in:
@@ -10,19 +10,20 @@ import (
|
|||||||
|
|
||||||
func TestToEmbeddingList(t *testing.T) {
|
func TestToEmbeddingList(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
embeddings [][]float32
|
embeddings [][]float32
|
||||||
format string
|
format string
|
||||||
expectType string // "float" or "base64"
|
expectType string // "float" or "base64"
|
||||||
expectCount int
|
expectBase64 []string
|
||||||
promptEval int
|
expectCount int
|
||||||
|
promptEval int
|
||||||
}{
|
}{
|
||||||
{"float format", [][]float32{{0.1, -0.2, 0.3}}, "float", "float", 1, 10},
|
{"float format", [][]float32{{0.1, -0.2, 0.3}}, "float", "float", nil, 1, 10},
|
||||||
{"base64 format", [][]float32{{0.1, -0.2, 0.3}}, "base64", "base64", 1, 5},
|
{"base64 format", [][]float32{{0.1, -0.2, 0.3}}, "base64", "base64", []string{"zczMPc3MTL6amZk+"}, 1, 5},
|
||||||
{"default to float", [][]float32{{0.1, -0.2, 0.3}}, "", "float", 1, 0},
|
{"default to float", [][]float32{{0.1, -0.2, 0.3}}, "", "float", nil, 1, 0},
|
||||||
{"invalid defaults to float", [][]float32{{0.1, -0.2, 0.3}}, "invalid", "float", 1, 0},
|
{"invalid defaults to float", [][]float32{{0.1, -0.2, 0.3}}, "invalid", "float", nil, 1, 0},
|
||||||
{"multiple embeddings", [][]float32{{0.1, 0.2}, {0.3, 0.4}, {0.5, 0.6}}, "base64", "base64", 3, 0},
|
{"multiple embeddings", [][]float32{{0.1, 0.2}, {0.3, 0.4}, {0.5, 0.6}}, "base64", "base64", []string{"zczMPc3MTD4=", "mpmZPs3MzD4=", "AAAAP5qZGT8="}, 3, 0},
|
||||||
{"empty embeddings", nil, "float", "", 0, 0},
|
{"empty embeddings", nil, "float", "", nil, 0, 0},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
@@ -56,11 +57,24 @@ func TestToEmbeddingList(t *testing.T) {
|
|||||||
t.Errorf("expected []float32, got %T", result.Data[0].Embedding)
|
t.Errorf("expected []float32, got %T", result.Data[0].Embedding)
|
||||||
}
|
}
|
||||||
case "base64":
|
case "base64":
|
||||||
embStr, ok := result.Data[0].Embedding.(string)
|
for i, data := range result.Data {
|
||||||
if !ok {
|
embStr, ok := data.Embedding.(string)
|
||||||
t.Errorf("expected string, got %T", result.Data[0].Embedding)
|
if !ok {
|
||||||
} else if _, err := base64.StdEncoding.DecodeString(embStr); err != nil {
|
t.Errorf("embedding %d: expected string, got %T", i, data.Embedding)
|
||||||
t.Errorf("invalid base64: %v", err)
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it's valid base64
|
||||||
|
if _, err := base64.StdEncoding.DecodeString(embStr); err != nil {
|
||||||
|
t.Errorf("embedding %d: invalid base64: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compare against expected base64 string if provided
|
||||||
|
if tc.expectBase64 != nil && i < len(tc.expectBase64) {
|
||||||
|
if embStr != tc.expectBase64[i] {
|
||||||
|
t.Errorf("embedding %d: expected base64 %q, got %q", i, tc.expectBase64[i], embStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user