mirror of
https://github.com/ollama/ollama.git
synced 2025-11-11 02:37:49 +01:00
embed: add distance correlation test for library embed models (#12796)
This commit is contained in:
@@ -14,6 +14,10 @@ import (
|
||||
|
||||
func dotProduct[V float32 | float64](v1, v2 []V) V {
|
||||
var result V = 0
|
||||
if len(v1) != len(v2) {
|
||||
return result
|
||||
}
|
||||
|
||||
for i := 0; i < len(v1); i++ {
|
||||
result += v1[i] * v2[i]
|
||||
}
|
||||
@@ -29,9 +33,115 @@ func magnitude[V float32 | float64](v []V) V {
|
||||
}
|
||||
|
||||
func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
|
||||
mag1 := magnitude(v1)
|
||||
mag2 := magnitude(v2)
|
||||
|
||||
if mag1 == 0 || mag2 == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return dotProduct(v1, v2) / (magnitude(v1) * magnitude(v2))
|
||||
}
|
||||
|
||||
func euclideanDistance[V float32 | float64](v1, v2 []V) V {
|
||||
if len(v1) != len(v2) {
|
||||
return V(math.Inf(1))
|
||||
}
|
||||
|
||||
var sum V = 0
|
||||
for i := 0; i < len(v1); i++ {
|
||||
diff := v1[i] - v2[i]
|
||||
sum += diff * diff
|
||||
}
|
||||
|
||||
return V(math.Sqrt(float64(sum)))
|
||||
}
|
||||
|
||||
func manhattanDistance[V float32 | float64](v1, v2 []V) V {
|
||||
if len(v1) != len(v2) {
|
||||
return V(math.Inf(1))
|
||||
}
|
||||
|
||||
var sum V = 0
|
||||
for i := 0; i < len(v1); i++ {
|
||||
sum += V(math.Abs(float64(v1[i] - v2[i])))
|
||||
}
|
||||
|
||||
return sum
|
||||
}
|
||||
|
||||
func TestEmbedCosineDistanceCorrelation(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
for _, model := range libraryEmbedModels {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
a string
|
||||
b string
|
||||
c string
|
||||
}{
|
||||
{"cat", "kitten", "dog"},
|
||||
{"king", "queen", "baron"},
|
||||
{"paris", "london", "vancouver"},
|
||||
{"The cat is sleeping on the sofa", "A feline is sleeping on the couch", "Quantum physics is complex"},
|
||||
{"I love programming in python", "Coding in python brings me joy", "Pizza is delicious"},
|
||||
{"Machine learning is fascinating", "Artificial intelligence is amazing", "I need to buy groceries"},
|
||||
{"The quick brown fox jumps over the lazy dog", "A fast brown fox leaps over a sleepy dog", "The weather is warm and sunny today"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
testEmbed := make(map[string][]float32)
|
||||
strs := []string{tc.a, tc.b, tc.c}
|
||||
|
||||
req := api.EmbedRequest{
|
||||
Model: model,
|
||||
Input: strs,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
}
|
||||
|
||||
resp, err := embedTestHelper(ctx, client, t, req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for cnt, v := range resp.Embeddings {
|
||||
testEmbed[strs[cnt]] = v
|
||||
}
|
||||
|
||||
// Calculate cosine similarities
|
||||
cosAB := cosineSimilarity(testEmbed[tc.a], testEmbed[tc.b])
|
||||
cosAC := cosineSimilarity(testEmbed[tc.a], testEmbed[tc.c])
|
||||
|
||||
// Calculate distances
|
||||
distAB := euclideanDistance(testEmbed[tc.a], testEmbed[tc.b])
|
||||
distAC := euclideanDistance(testEmbed[tc.a], testEmbed[tc.c])
|
||||
|
||||
manhattanAB := manhattanDistance(testEmbed[tc.a], testEmbed[tc.b])
|
||||
manhattanAC := manhattanDistance(testEmbed[tc.a], testEmbed[tc.c])
|
||||
|
||||
// Consistency check: if cosAB > cosAC, then distances should be smaller
|
||||
if cosAB > cosAC {
|
||||
if distAB >= distAC {
|
||||
t.Errorf("Euclidean distance inconsistency (%s) for %s-%s-%s: cosAB=%f > cosAC=%f but distAB=%f >= distAC=%f",
|
||||
model, tc.a, tc.b, tc.c, cosAB, cosAC, distAB, distAC)
|
||||
}
|
||||
|
||||
if manhattanAB >= manhattanAC {
|
||||
t.Errorf("Manhattan distance inconsistency (%s) for %s-%s-%s: cosAB=%f > cosAC=%f but manhattanAB=%f >= manhattanAC=%f",
|
||||
model, tc.a, tc.b, tc.c, cosAB, cosAC, manhattanAB, manhattanAC)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("Cosine Similarity inconsistency (%s): cosinSim(%s, %s) < cosinSim(%s, %s)",
|
||||
model, tc.a, tc.b, tc.a, tc.c)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllMiniLMEmbeddings(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
@@ -248,12 +248,14 @@ var (
|
||||
"zephyr",
|
||||
}
|
||||
libraryEmbedModels = []string{
|
||||
"qwen3-embedding",
|
||||
"embeddinggemma",
|
||||
"nomic-embed-text",
|
||||
"all-minilm",
|
||||
"bge-large",
|
||||
"bge-m3",
|
||||
"granite-embedding",
|
||||
"mxbai-embed-large",
|
||||
"nomic-embed-text",
|
||||
"paraphrase-multilingual",
|
||||
"snowflake-arctic-embed",
|
||||
"snowflake-arctic-embed2",
|
||||
|
||||
Reference in New Issue
Block a user