From 36d64fb5314327302ba25853f45d4b3c4d66c768 Mon Sep 17 00:00:00 2001
From: Patrick Devine
Date: Tue, 28 Oct 2025 16:57:27 -0700
Subject: [PATCH] embed: add distance correlation test for library embed models (#12796)

---
Reviewer revisions relative to the patch as originally posted:
 * cosineSimilarity reuses mag1/mag2 instead of computing each magnitude twice.
 * The test fails fast when the number of returned embeddings differs from the
   number of inputs, instead of panicking on the index or comparing nil vectors.
 * The cosine failure message now says "<=" (the condition actually tested) and
   prints both similarity values, like the two distance messages do.
 * Doc comments added to the new helpers and to the new test.
 * Hunk counts and diffstat were updated to match; the post-image blob id on the
   embed_test.go index line is the one from the original patch.

 integration/embed_test.go | 127 +++++++++++++++++++++++++++++++++++++-
 integration/utils_test.go |   4 +-
 2 files changed, 129 insertions(+), 2 deletions(-)

diff --git a/integration/embed_test.go b/integration/embed_test.go
index 3a8bcd2482..e155498dbf 100644
--- a/integration/embed_test.go
+++ b/integration/embed_test.go
@@ -14,6 +14,10 @@ import (
 
 func dotProduct[V float32 | float64](v1, v2 []V) V {
 	var result V = 0
+	if len(v1) != len(v2) {
+		return result
+	}
+
 	for i := 0; i < len(v1); i++ {
 		result += v1[i] * v2[i]
 	}
@@ -29,9 +33,130 @@ func magnitude[V float32 | float64](v []V) V {
 }
 
+// cosineSimilarity returns the cosine of the angle between v1 and v2, or 0
+// when either vector has zero magnitude.
 func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
-	return dotProduct(v1, v2) / (magnitude(v1) * magnitude(v2))
+	mag1 := magnitude(v1)
+	mag2 := magnitude(v2)
+
+	if mag1 == 0 || mag2 == 0 {
+		return 0
+	}
+
+	return dotProduct(v1, v2) / (mag1 * mag2)
 }
 
+// euclideanDistance returns the L2 distance between v1 and v2, or +Inf when
+// the vectors differ in length.
+func euclideanDistance[V float32 | float64](v1, v2 []V) V {
+	if len(v1) != len(v2) {
+		return V(math.Inf(1))
+	}
+
+	var sum V = 0
+	for i := 0; i < len(v1); i++ {
+		diff := v1[i] - v2[i]
+		sum += diff * diff
+	}
+
+	return V(math.Sqrt(float64(sum)))
+}
+
+// manhattanDistance returns the L1 distance between v1 and v2, or +Inf when
+// the vectors differ in length.
+func manhattanDistance[V float32 | float64](v1, v2 []V) V {
+	if len(v1) != len(v2) {
+		return V(math.Inf(1))
+	}
+
+	var sum V = 0
+	for i := 0; i < len(v1); i++ {
+		sum += V(math.Abs(float64(v1[i] - v2[i])))
+	}
+
+	return sum
+}
+
+// TestEmbedCosineDistanceCorrelation embeds (a, b, c) triples with every
+// library embedding model and checks that a is closer to b than to c under
+// cosine similarity, Euclidean distance and Manhattan distance.
+func TestEmbedCosineDistanceCorrelation(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	for _, model := range libraryEmbedModels {
+		t.Run(model, func(t *testing.T) {
+			testCases := []struct {
+				a string
+				b string
+				c string
+			}{
+				{"cat", "kitten", "dog"},
+				{"king", "queen", "baron"},
+				{"paris", "london", "vancouver"},
+				{"The cat is sleeping on the sofa", "A feline is sleeping on the couch", "Quantum physics is complex"},
+				{"I love programming in python", "Coding in python brings me joy", "Pizza is delicious"},
+				{"Machine learning is fascinating", "Artificial intelligence is amazing", "I need to buy groceries"},
+				{"The quick brown fox jumps over the lazy dog", "A fast brown fox leaps over a sleepy dog", "The weather is warm and sunny today"},
+			}
+
+			for _, tc := range testCases {
+				testEmbed := make(map[string][]float32)
+				strs := []string{tc.a, tc.b, tc.c}
+
+				req := api.EmbedRequest{
+					Model:     model,
+					Input:     strs,
+					KeepAlive: &api.Duration{Duration: 10 * time.Second},
+				}
+
+				resp, err := embedTestHelper(ctx, client, t, req)
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				// Guard the index into strs below and make a short response a
+				// clear failure rather than a comparison of missing vectors.
+				if len(resp.Embeddings) != len(strs) {
+					t.Fatalf("expected %d embeddings, got %d", len(strs), len(resp.Embeddings))
+				}
+
+				for cnt, v := range resp.Embeddings {
+					testEmbed[strs[cnt]] = v
+				}
+
+				// Calculate cosine similarities
+				cosAB := cosineSimilarity(testEmbed[tc.a], testEmbed[tc.b])
+				cosAC := cosineSimilarity(testEmbed[tc.a], testEmbed[tc.c])
+
+				// Calculate distances
+				distAB := euclideanDistance(testEmbed[tc.a], testEmbed[tc.b])
+				distAC := euclideanDistance(testEmbed[tc.a], testEmbed[tc.c])
+
+				manhattanAB := manhattanDistance(testEmbed[tc.a], testEmbed[tc.b])
+				manhattanAC := manhattanDistance(testEmbed[tc.a], testEmbed[tc.c])
+
+				// Consistency check: if cosAB > cosAC, then distances should be smaller
+				if cosAB > cosAC {
+					if distAB >= distAC {
+						t.Errorf("Euclidean distance inconsistency (%s) for %s-%s-%s: cosAB=%f > cosAC=%f but distAB=%f >= distAC=%f",
+							model, tc.a, tc.b, tc.c, cosAB, cosAC, distAB, distAC)
+					}
+
+					if manhattanAB >= manhattanAC {
+						t.Errorf("Manhattan distance inconsistency (%s) for %s-%s-%s: cosAB=%f > cosAC=%f but manhattanAB=%f >= manhattanAC=%f",
+							model, tc.a, tc.b, tc.c, cosAB, cosAC, manhattanAB, manhattanAC)
+					}
+				} else {
+					t.Errorf("Cosine Similarity inconsistency (%s): cosineSim(%s, %s)=%f <= cosineSim(%s, %s)=%f",
+						model, tc.a, tc.b, cosAB, tc.a, tc.c, cosAC)
+				}
+			}
+		})
+	}
+}
+
 func TestAllMiniLMEmbeddings(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
 	defer cancel()
diff --git a/integration/utils_test.go b/integration/utils_test.go
index c438aa9306..66e8d73165 100644
--- a/integration/utils_test.go
+++ b/integration/utils_test.go
@@ -248,12 +248,14 @@ var (
 		"zephyr",
 	}
 	libraryEmbedModels = []string{
+		"qwen3-embedding",
+		"embeddinggemma",
+		"nomic-embed-text",
 		"all-minilm",
 		"bge-large",
 		"bge-m3",
 		"granite-embedding",
 		"mxbai-embed-large",
-		"nomic-embed-text",
 		"paraphrase-multilingual",
 		"snowflake-arctic-embed",
 		"snowflake-arctic-embed2",