embed: add distance correlation test for library embed models (#12796)

This commit is contained in:
Patrick Devine
2025-10-28 16:57:27 -07:00
committed by GitHub
parent d828517e78
commit 36d64fb531
2 changed files with 113 additions and 1 deletions

View File

@@ -14,6 +14,10 @@ import (
func dotProduct[V float32 | float64](v1, v2 []V) V { func dotProduct[V float32 | float64](v1, v2 []V) V {
var result V = 0 var result V = 0
if len(v1) != len(v2) {
return result
}
for i := 0; i < len(v1); i++ { for i := 0; i < len(v1); i++ {
result += v1[i] * v2[i] result += v1[i] * v2[i]
} }
@@ -29,9 +33,115 @@ func magnitude[V float32 | float64](v []V) V {
} }
func cosineSimilarity[V float32 | float64](v1, v2 []V) V { func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
mag1 := magnitude(v1)
mag2 := magnitude(v2)
if mag1 == 0 || mag2 == 0 {
return 0
}
return dotProduct(v1, v2) / (magnitude(v1) * magnitude(v2)) return dotProduct(v1, v2) / (magnitude(v1) * magnitude(v2))
} }
func euclideanDistance[V float32 | float64](v1, v2 []V) V {
if len(v1) != len(v2) {
return V(math.Inf(1))
}
var sum V = 0
for i := 0; i < len(v1); i++ {
diff := v1[i] - v2[i]
sum += diff * diff
}
return V(math.Sqrt(float64(sum)))
}
func manhattanDistance[V float32 | float64](v1, v2 []V) V {
if len(v1) != len(v2) {
return V(math.Inf(1))
}
var sum V = 0
for i := 0; i < len(v1); i++ {
sum += V(math.Abs(float64(v1[i] - v2[i])))
}
return sum
}
func TestEmbedCosineDistanceCorrelation(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
for _, model := range libraryEmbedModels {
t.Run(model, func(t *testing.T) {
testCases := []struct {
a string
b string
c string
}{
{"cat", "kitten", "dog"},
{"king", "queen", "baron"},
{"paris", "london", "vancouver"},
{"The cat is sleeping on the sofa", "A feline is sleeping on the couch", "Quantum physics is complex"},
{"I love programming in python", "Coding in python brings me joy", "Pizza is delicious"},
{"Machine learning is fascinating", "Artificial intelligence is amazing", "I need to buy groceries"},
{"The quick brown fox jumps over the lazy dog", "A fast brown fox leaps over a sleepy dog", "The weather is warm and sunny today"},
}
for _, tc := range testCases {
testEmbed := make(map[string][]float32)
strs := []string{tc.a, tc.b, tc.c}
req := api.EmbedRequest{
Model: model,
Input: strs,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
}
resp, err := embedTestHelper(ctx, client, t, req)
if err != nil {
t.Fatal(err)
}
for cnt, v := range resp.Embeddings {
testEmbed[strs[cnt]] = v
}
// Calculate cosine similarities
cosAB := cosineSimilarity(testEmbed[tc.a], testEmbed[tc.b])
cosAC := cosineSimilarity(testEmbed[tc.a], testEmbed[tc.c])
// Calculate distances
distAB := euclideanDistance(testEmbed[tc.a], testEmbed[tc.b])
distAC := euclideanDistance(testEmbed[tc.a], testEmbed[tc.c])
manhattanAB := manhattanDistance(testEmbed[tc.a], testEmbed[tc.b])
manhattanAC := manhattanDistance(testEmbed[tc.a], testEmbed[tc.c])
// Consistency check: if cosAB > cosAC, then distances should be smaller
if cosAB > cosAC {
if distAB >= distAC {
t.Errorf("Euclidean distance inconsistency (%s) for %s-%s-%s: cosAB=%f > cosAC=%f but distAB=%f >= distAC=%f",
model, tc.a, tc.b, tc.c, cosAB, cosAC, distAB, distAC)
}
if manhattanAB >= manhattanAC {
t.Errorf("Manhattan distance inconsistency (%s) for %s-%s-%s: cosAB=%f > cosAC=%f but manhattanAB=%f >= manhattanAC=%f",
model, tc.a, tc.b, tc.c, cosAB, cosAC, manhattanAB, manhattanAC)
}
} else {
t.Errorf("Cosine Similarity inconsistency (%s): cosinSim(%s, %s) < cosinSim(%s, %s)",
model, tc.a, tc.b, tc.a, tc.c)
}
}
})
}
}
func TestAllMiniLMEmbeddings(t *testing.T) { func TestAllMiniLMEmbeddings(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel() defer cancel()

View File

@@ -248,12 +248,14 @@ var (
"zephyr", "zephyr",
} }
libraryEmbedModels = []string{ libraryEmbedModels = []string{
"qwen3-embedding",
"embeddinggemma",
"nomic-embed-text",
"all-minilm", "all-minilm",
"bge-large", "bge-large",
"bge-m3", "bge-m3",
"granite-embedding", "granite-embedding",
"mxbai-embed-large", "mxbai-embed-large",
"nomic-embed-text",
"paraphrase-multilingual", "paraphrase-multilingual",
"snowflake-arctic-embed", "snowflake-arctic-embed",
"snowflake-arctic-embed2", "snowflake-arctic-embed2",